tellurion committed on
Commit
d066167
·
0 Parent(s):

initialize huggingface space demo

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. app.py +227 -0
  2. backend/__init__.py +16 -0
  3. backend/appfunc.py +298 -0
  4. backend/functool.py +276 -0
  5. backend/style.css +181 -0
  6. configs/inference/sdxl.yaml +88 -0
  7. configs/inference/xlv2.yaml +108 -0
  8. configs/scheduler_cfgs/ddim.yaml +10 -0
  9. configs/scheduler_cfgs/dpm.yaml +8 -0
  10. configs/scheduler_cfgs/dpm_sde.yaml +9 -0
  11. configs/scheduler_cfgs/lms.yaml +9 -0
  12. configs/scheduler_cfgs/pndm.yaml +10 -0
  13. k_diffusion/__init__.py +8 -0
  14. k_diffusion/external.py +181 -0
  15. k_diffusion/sampling.py +702 -0
  16. k_diffusion/utils.py +457 -0
  17. ldm/modules/diffusionmodules/__init__.py +0 -0
  18. ldm/modules/diffusionmodules/model.py +488 -0
  19. ldm/modules/distributions/__init__.py +0 -0
  20. ldm/modules/distributions/distributions.py +92 -0
  21. preprocessor/__init__.py +124 -0
  22. preprocessor/anime2sketch.py +119 -0
  23. preprocessor/anime_segment.py +487 -0
  24. preprocessor/manga_line_extractor.py +187 -0
  25. preprocessor/sk_model.py +94 -0
  26. preprocessor/sketchKeras.py +153 -0
  27. refnet/__init__.py +0 -0
  28. refnet/ldm/__init__.py +1 -0
  29. refnet/ldm/ddpm.py +236 -0
  30. refnet/ldm/openaimodel.py +386 -0
  31. refnet/ldm/util.py +289 -0
  32. refnet/modules/__init__.py +34 -0
  33. refnet/modules/attention.py +309 -0
  34. refnet/modules/attn_utils.py +155 -0
  35. refnet/modules/embedder.py +489 -0
  36. refnet/modules/encoder.py +224 -0
  37. refnet/modules/layers.py +99 -0
  38. refnet/modules/lora.py +370 -0
  39. refnet/modules/proj.py +142 -0
  40. refnet/modules/reference_net.py +430 -0
  41. refnet/modules/transformer.py +232 -0
  42. refnet/modules/unet.py +421 -0
  43. refnet/modules/unet_old.py +596 -0
  44. refnet/sampling/__init__.py +11 -0
  45. refnet/sampling/denoiser.py +181 -0
  46. refnet/sampling/hook.py +257 -0
  47. refnet/sampling/manipulation.py +135 -0
  48. refnet/sampling/sampler.py +192 -0
  49. refnet/sampling/scheduler.py +42 -0
  50. refnet/sampling/tps_transformation.py +203 -0
app.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import argparse
3
+
4
+ from refnet.sampling import get_noise_schedulers, get_sampler_list
5
+ from functools import partial
6
+ from backend import *
7
+
8
# External resources rendered as badges in the app header (see init_interface):
# papers for each model generation, the weight repository, and the GitHub repo.
links = {
    "base": "https://arxiv.org/abs/2401.01456",
    "v1": "https://openaccess.thecvf.com/content/WACV2025/html/Yan_ColorizeDiffusion_Improving_Reference-Based_Sketch_Colorization_with_Latent_Diffusion_Model_WACV_2025_paper.html",
    "v1.5": "https://arxiv.org/abs/2502.19937v1",
    "v2": "https://arxiv.org/abs/2504.06895",
    "xl": "https://arxiv.org/abs/2601.04883",
    "weights": "https://huggingface.co/tellurion/colorizer/tree/main",
    "github": "https://github.com/tellurion-kanata/colorizeDiffusion",
}
17
+
18
def app_options(args=None):
    """Parse the demo's command-line options.

    Args:
        args: Optional explicit argument list. Defaults to ``None``, in which
            case ``sys.argv[1:]`` is used (the original behavior); passing a
            list makes the function usable programmatically and testable.

    Returns:
        argparse.Namespace with ``server_name``, ``server_port``, ``share``
        and ``enable_text_manipulation``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--server_name", '-addr', type=str, default="0.0.0.0")
    parser.add_argument("--server_port", '-port', type=int, default=7860)
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--enable_text_manipulation", '-manipulate', action="store_true")
    return parser.parse_args(args)
25
+
26
+
27
def init_interface(opt, *args, **kwargs) -> None:
    """Build the Gradio UI, wire its events, and launch the server.

    Blocks in ``block.launch`` until the server stops; returns None, so
    callers should not rely on a return value.
    """
    sampler_list = get_sampler_list()
    scheduler_list = get_noise_schedulers()

    # Shared factory so the three image inputs are configured identically.
    img_block = partial(gr.Image, type="pil", height=300, interactive=True, show_label=True, format="png")
    with gr.Blocks(
        title = "Colorize Diffusion",
        css_paths = "backend/style.css",
        theme = gr.themes.Ocean(),
        elem_id = "main-interface",
        analytics_enabled = False,
        fill_width = True
    ) as block:
        # ---- Header: app title plus paper / weights / repo badges ----
        with gr.Row(elem_id="header-row", equal_height=True, variant="panel"):
            gr.Markdown(f"""<div class="header-container">
                <div class="app-header"><span class="emoji">🎨</span><span class="title-text">Colorize Diffusion</span></div>
                <div class="paper-links-icons">
                    <a href="{links['base']}" target="_blank">
                        <img src="https://img.shields.io/badge/arXiv-2407.15886 (base)-B31B1B?style=flat&logo=arXiv" alt="arXiv Paper">
                    </a>
                    <a href="{links['v1']}" target="_blank">
                        <img src="https://img.shields.io/badge/WACV 2025-v1-0CA4A5?style=flat&logo=Semantic%20Web" alt="WACV 2025">
                    </a>
                    <a href="{links['v1.5']}" target="_blank">
                        <img src="https://img.shields.io/badge/CVPR 2025-v1.5-0CA4A5?style=flat&logo=Semantic%20Web" alt="CVPR 2025">
                    </a>
                    <a href="{links['v2']}" target="_blank">
                        <img src="https://img.shields.io/badge/arXiv-2504.06895 (v2)-B31B1B?style=flat&logo=arXiv" alt="arXiv v2 Paper">
                    </a>
                    <a href="{links['weights']}" target="_blank">
                        <img src="https://img.shields.io/badge/Hugging%20Face-Model%20Weights-FF9D00?style=flat&logo=Hugging%20Face" alt="Model Weights">
                    </a>
                    <a href="{links['github']}" target="_blank">
                        <img src="https://img.shields.io/badge/GitHub-Repo-blue?style=flat&logo=GitHub" alt="GitHub">
                    </a>
                    <a href="https://github.com/tellurion-kanata/colorizeDiffusion/blob/master/LICENSE" target="_blank">
                        <img src="https://img.shields.io/badge/License-CC--BY--NC--SA%204.0-4CAF50?style=flat&logo=Creative%20Commons" alt="License">
                    </a>
                </div>
            </div>""")

        with gr.Row(elem_id="content-row", equal_height=False, variant="panel"):
            # ---- Left column: inputs and generation parameters ----
            with gr.Column():
                # Text-manipulation widgets; hidden unless the flag was passed.
                with gr.Row(visible=opt.enable_text_manipulation):
                    target = gr.Textbox(label="Target prompt", value="", scale=2)
                    anchor = gr.Textbox(label="Anchor prompt", value="", scale=2)
                    control = gr.Textbox(label="Control prompt", value="", scale=2)
                with gr.Row(visible=opt.enable_text_manipulation):
                    target_scale = gr.Slider(label="Target scale", value=0.0, minimum=0, maximum=15.0, step=0.25, scale=2)
                    ts0 = gr.Slider(label="Threshold 0", value=0.5, minimum=0, maximum=1.0, step=0.01)
                    ts1 = gr.Slider(label="Threshold 1", value=0.55, minimum=0, maximum=1.0, step=0.01)
                    ts2 = gr.Slider(label="Threshold 2", value=0.65, minimum=0, maximum=1.0, step=0.01)
                    ts3 = gr.Slider(label="Threshold 3", value=0.95, minimum=0, maximum=1.0, step=0.01)
                with gr.Row(visible=opt.enable_text_manipulation):
                    enhance = gr.Checkbox(label="Enhance manipulation", value=False)
                    add_prompt = gr.Button(value="Add")
                    clear_prompt = gr.Button(value="Clear")
                    vis_button = gr.Button(value="Visualize")
                text_prompt = gr.Textbox(label="Final prompt", value="", lines=3, visible=opt.enable_text_manipulation)

                # Image inputs.
                with gr.Row():
                    sketch_img = img_block(label="Sketch")
                    reference_img = img_block(label="Reference")
                    background_img = img_block(label="Background")

                # Always-off toggles kept as states so the `ips` list keeps
                # the signature expected by backend.inference.
                style_enhance = gr.State(False)
                fg_enhance = gr.State(False)
                with gr.Row():
                    bg_enhance = gr.Checkbox(label="Low-level injection", value=False)
                    injection = gr.Checkbox(label="Attention injection", value=False)
                    autofit_size = gr.Checkbox(label="Autofit size", value=False)
                with gr.Row():
                    gs_r = gr.Slider(label="Reference guidance scale", minimum=1, maximum=15.0, value=4.0, step=0.5)
                    strength = gr.Slider(label="Reference strength", minimum=0, maximum=1, value=1, step=0.05)
                    fg_strength = gr.Slider(label="Foreground strength", minimum=0, maximum=1, value=1, step=0.05)
                    bg_strength = gr.Slider(label="Background strength", minimum=0, maximum=1, value=1, step=0.05)
                with gr.Row():
                    gs_s = gr.Slider(label="Sketch guidance scale", minimum=1, maximum=5.0, value=1.0, step=0.1)
                    ctl_scale = gr.Slider(label="Sketch strength", minimum=0, maximum=3, value=1, step=0.05)
                    mask_scale = gr.Slider(label="Background factor", minimum=0, maximum=2, value=1, step=0.05)
                    merge_scale = gr.Slider(label="Merging scale", minimum=0, maximum=1, value=0, step=0.05)
                with gr.Row():
                    bs = gr.Slider(label="Batch size", minimum=1, maximum=4, value=1, step=1, scale=1)
                    width = gr.Slider(label="Width", minimum=512, maximum=1536, value=1024, step=32, scale=2)
                with gr.Row():
                    step = gr.Slider(label="Step", minimum=1, maximum=100, value=20, step=1, scale=1)
                    height = gr.Slider(label="Height", minimum=512, maximum=1536, value=1024, step=32, scale=2)

                # -1 requests a random seed (see backend.inference).
                seed = gr.Slider(label="Seed", minimum=-1, maximum=MAXM_INT32, step=1, value=-1)
                with gr.Accordion("Advanced Settings", open=False):
                    with gr.Row():
                        crop = gr.Checkbox(label="Crop result", value=False, scale=1)
                        remove_fg = gr.Checkbox(label="Remove foreground in background input", value=False, scale=2)
                        rmbg = gr.Checkbox(label="Remove background in result", value=False, scale=2)
                        latent_inpaint = gr.Checkbox(label="Latent copy BG input", value=False, scale=2)
                    with gr.Row():
                        injection_control_scale = gr.Slider(label="Injection fidelity (sketch)", minimum=0.0,
                                                            maximum=2.0, value=0, step=0.05)
                        injection_fidelity = gr.Slider(label="Injection fidelity (reference)", minimum=0.0,
                                                       maximum=1.0, value=0.5, step=0.05)
                        injection_start_step = gr.Slider(label="Injection start step", minimum=0.0, maximum=1.0,
                                                         value=0, step=0.05)

                with gr.Row():
                    reuse_seed = gr.Button(value="Reuse Seed")
                    random_seed = gr.Button(value="Random Seed")

            # ---- Right column: outputs and model/sampler selection ----
            with gr.Column():
                result_gallery = gr.Gallery(
                    label='Output', show_label=False, elem_id="gallery", preview=True, type="pil", format="png"
                )
                run_button = gr.Button("Generate", variant="primary", size="lg")
                with gr.Row():
                    mask_ts = gr.Slider(label="Reference mask threshold", minimum=0., maximum=1., value=0.5, step=0.01)
                    mask_ss = gr.Slider(label="Sketch mask threshold", minimum=0., maximum=1., value=0.05, step=0.01)
                    pad_scale = gr.Slider(label="Reference padding scale", minimum=1, maximum=2, value=1, step=0.05)

                with gr.Row():
                    sd_model = gr.Dropdown(choices=get_available_models(), label="Models",
                                           value=get_available_models()[0])
                    extractor_model = gr.Dropdown(choices=line_extractor_list,
                                                  label="Line extractor", value=default_line_extractor)
                    mask_model = gr.Dropdown(choices=mask_extractor_list, label="Reference mask extractor",
                                             value=default_mask_extractor)
                with gr.Row():
                    sampler = gr.Dropdown(choices=sampler_list, value="DPM++ 3M SDE", label="Sampler")
                    scheduler = gr.Dropdown(choices=scheduler_list, value=scheduler_list[0], label="Noise scheduler")
                    preprocessor = gr.Dropdown(choices=["none", "extract", "invert", "invert-webui"],
                                               label="Sketch preprocessor", value="invert")

                with gr.Row():
                    deterministic = gr.Checkbox(label="Deterministic batch seed", value=False)
                    save_memory = gr.Checkbox(label="Save memory", value=True)

                # Hidden states for unused advanced controls
                fg_disentangle_scale = gr.State(1.0)
                start_step = gr.State(0.0)
                end_step = gr.State(1.0)
                no_start_step = gr.State(-0.05)
                no_end_step = gr.State(-0.05)
                return_inter = gr.State(False)
                accurate = gr.State(False)
                enc_scale = gr.State(1.0)
                middle_scale = gr.State(1.0)
                low_scale = gr.State(1.0)
                ctl_scale_1 = gr.State(1.0)
                ctl_scale_2 = gr.State(1.0)
                ctl_scale_3 = gr.State(1.0)
                ctl_scale_4 = gr.State(1.0)

        # ---- Event wiring ----
        # NOTE: fn name "apppend_prompt" (triple p) matches the backend's
        # actual definition; do not "correct" only one side.
        add_prompt.click(fn=apppend_prompt,
                         inputs=[target, anchor, control, target_scale, enhance, ts0, ts1, ts2, ts3, text_prompt],
                         outputs=[target, anchor, control, target_scale, enhance, ts0, ts1, ts2, ts3, text_prompt])
        clear_prompt.click(fn=clear_prompts, outputs=[text_prompt])

        reuse_seed.click(fn=get_last_seed, outputs=[seed])
        random_seed.click(fn=reset_random_seed, outputs=[seed])

        extractor_model.input(fn=switch_extractor, inputs=[extractor_model])
        sd_model.input(fn=load_model, inputs=[sd_model])
        mask_model.input(fn=switch_mask_extractor, inputs=[mask_model])

        # Positional argument list for backend.inference; order must match
        # its signature exactly.
        ips = [style_enhance, bg_enhance, fg_enhance, fg_disentangle_scale,
               bs, sketch_img, reference_img, background_img, mask_ts, mask_ss, gs_r, gs_s, ctl_scale,
               ctl_scale_1, ctl_scale_2, ctl_scale_3, ctl_scale_4, fg_strength, bg_strength, merge_scale,
               mask_scale, height, width, seed, save_memory, step, injection, autofit_size,
               remove_fg, rmbg, latent_inpaint, injection_control_scale, injection_fidelity, injection_start_step,
               crop, pad_scale, start_step, end_step, no_start_step, no_end_step, return_inter, sampler, scheduler,
               preprocessor, deterministic, text_prompt, target, anchor, control, target_scale, ts0, ts1, ts2, ts3,
               enhance, accurate, enc_scale, middle_scale, low_scale, strength]

        run_button.click(
            fn = inference,
            inputs = ips,
            outputs = [result_gallery],
        )

        vis_button.click(
            fn = visualize,
            inputs = [reference_img, text_prompt, control, ts0, ts1, ts2, ts3],
            outputs = [result_gallery],
        )

    # Launch after construction; blocks until shutdown.
    block.launch(
        server_name = opt.server_name,
        share = opt.share,
        server_port = opt.server_port,
    )
215
+
216
+
217
if __name__ == '__main__':
    opt = app_options()
    try:
        # Warm everything up before building the UI so the first request
        # does not pay the model-loading cost.
        models = get_available_models()
        load_model(models[0])
        switch_extractor(default_line_extractor)
        switch_mask_extractor(default_mask_extractor)
        # init_interface launches the server and returns None.
        interface = init_interface(opt)
    except Exception as e:
        # Log and re-raise so the Space shows a failed startup instead of
        # silently serving a broken UI.
        print(f"Error initializing interface: {e}")
        raise
backend/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Re-export the Gradio backend helpers used by app.py.
from .appfunc import *


# Public API of the package (what `from backend import *` provides).
# NOTE: 'apppend_prompt' (triple p) intentionally matches the function's
# actual (misspelled) name in appfunc.py.
__all__ = [
    'switch_extractor', 'switch_mask_extractor',
    'get_available_models', 'load_model', 'inference', 'reset_random_seed', 'get_last_seed',
    'apppend_prompt', 'clear_prompts', 'visualize',
    'default_line_extractor', 'default_mask_extractor', 'MAXM_INT32',
    'mask_extractor_list', 'line_extractor_list',
]


# Preprocessors selected at startup (see app.py __main__).
default_line_extractor = "lineart_keras"
default_mask_extractor = "rmbg-v2"
# Dropdown choices offered in the UI.
mask_extractor_list = ["none", "ISNet", "rmbg-v2", "BiRefNet", "BiRefNet_HR"]
line_extractor_list = ["lineart", "lineart_denoise", "lineart_keras", "lineart_sk"]
backend/appfunc.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import traceback
4
+ import gradio as gr
5
+ import os.path as osp
6
+
7
+ from huggingface_hub import hf_hub_download
8
+
9
+ from omegaconf import OmegaConf
10
+ from refnet.util import instantiate_from_config
11
+ from preprocessor import create_model
12
+ from .functool import *
13
+
14
# Currently loaded diffusion model; (re)bound by load_model().
model = None

model_type = ""           # config family of the loaded model ("sdxl" / "xlv2")
current_checkpoint = ""   # filename of the last successfully loaded checkpoint
global_seed = None        # seed used by the latest generation (for "Reuse Seed")

# Sketch-mask extractor kept resident on CPU; moved to GPU only while used.
smask_extractor = create_model("ISNet-sketch").cpu()

# NOTE(review): looks like a max-int constant with a dropped digit
# (int32 max is 2147483647; uint32 max is 4294967295, and 429496729 is
# exactly the latter / 10). Confirm the intended seed upper bound.
MAXM_INT32 = 429496729

# HuggingFace model repository
HF_REPO_ID = "tellurion/colorizer"
MODEL_CACHE_DIR = "models"

# Model registry: filename -> model_type
MODEL_REGISTRY = {
    "sdxl.safetensors": "sdxl",
    "xlv2.safetensors": "xlv2",
}

# Known type prefixes, used as a fallback when a checkpoint filename is
# not present in MODEL_REGISTRY (see load_model).
model_types = ["sdxl", "xlv2"]

'''
Gradio UI functions
'''
39
+
40
+
41
def get_available_models():
    """Return the checkpoint filenames known to MODEL_REGISTRY, in registry order."""
    return [*MODEL_REGISTRY]
44
+
45
+
46
def download_model(filename):
    """Download a model from HuggingFace Hub if not already cached.

    Args:
        filename: checkpoint filename inside HF_REPO_ID.

    Returns:
        Local filesystem path of the cached (or freshly downloaded) file.
    """
    os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
    local_path = osp.join(MODEL_CACHE_DIR, filename)
    if osp.exists(local_path):
        return local_path

    # Fix: the original messages contained the literal placeholder
    # "(unknown)" instead of interpolating the file being fetched.
    print(f"Downloading {filename} from {HF_REPO_ID}...")
    gr.Info(f"Downloading {filename}...")
    path = hf_hub_download(
        repo_id=HF_REPO_ID,
        filename=filename,
        local_dir=MODEL_CACHE_DIR,
    )
    print(f"Downloaded to {path}")
    return path
62
+
63
+
64
def switch_extractor(type):  # `type` shadows the builtin but is part of the UI-facing signature
    """Rebind the global line extractor to a freshly created `type` model.

    On failure the previous extractor stays bound and the error is surfaced
    in the UI rather than raised.
    """
    global line_extractor
    try:
        line_extractor = create_model(type)
        gr.Info(f"Switched to {type} extractor")
    except Exception as e:
        print(f"Error info: {e}")
        # Fix: traceback.print_exc() prints the traceback itself and returns
        # None — wrapping it in print() emitted a spurious "None" line.
        traceback.print_exc()
        gr.Info(f"Failed in loading {type} extractor")
73
+
74
+
75
def switch_mask_extractor(type):  # `type` shadows the builtin but is part of the UI-facing signature
    """Rebind the global reference-mask extractor to a `type` model.

    On failure the previous extractor stays bound and the error is surfaced
    in the UI rather than raised.
    """
    global mask_extractor
    try:
        mask_extractor = create_model(type)
        gr.Info(f"Switched to {type} extractor")
    except Exception as e:
        print(f"Error info: {e}")
        # Fix: traceback.print_exc() prints the traceback itself and returns
        # None — wrapping it in print() emitted a spurious "None" line.
        traceback.print_exc()
        gr.Info(f"Failed in loading {type} extractor")
84
+
85
+
86
def apppend_prompt(target, anchor, control, scale, enhance, ts0, ts1, ts2, ts3, prompt):
    """Serialize one manipulation entry onto `prompt` and reset the widgets.

    Blank target/anchor/control fields are recorded as "none". Returns the
    reset values for every input widget followed by the updated prompt text,
    matching the outputs list wired in app.py.
    """
    # Normalize the three text fields; empty strings become the "none" marker.
    target, anchor, control = (field.strip() or "none" for field in (target, anchor, control))
    entry = (
        "\n[target] {}; [anchor] {}; [control] {}; [scale] {}; "
        "[enhanced] {}; [ts0] {}; [ts1] {}; [ts2] {}; [ts3] {}"
    ).format(target, anchor, control, scale, enhance, ts0, ts1, ts2, ts3)
    widget_defaults = ("", "", "", 0.0, False, 0.5, 0.55, 0.65, 0.95)
    return (*widget_defaults, (prompt + entry).strip())
96
+
97
+
98
def clear_prompts():
    """Reset the accumulated manipulation prompt textbox to empty."""
    return ""
100
+
101
+
102
def load_model(ckpt_name):
    """Load checkpoint `ckpt_name`, rebuilding the model when its family changes.

    Resolves the model type via MODEL_REGISTRY (falling back to a filename
    prefix match against `model_types`), instantiates the matching inference
    YAML only when the type changed, downloads the weights from HF Hub, then
    loads them and switches to fp16. Errors are reported to the UI instead
    of being raised, leaving the previous model (if any) unusable state
    visible in the logs.
    """
    global model, model_type, current_checkpoint
    config_root = "configs/inference"

    try:
        # Determine model type from registry or filename prefix
        new_model_type = MODEL_REGISTRY.get(ckpt_name, "")
        if not new_model_type:
            for key in model_types:
                if ckpt_name.startswith(key):
                    new_model_type = key
                    break

        # Rebuild only when the architecture family changes (or no model yet).
        if model_type != new_model_type or not "model" in globals():
            if "model" in globals() and exists(model):
                # Drop the old model before allocating the new one to
                # reduce peak memory.
                del model
            config_path = osp.join(config_root, f"{new_model_type}.yaml")
            new_model = instantiate_from_config(OmegaConf.load(config_path).model).cpu().eval()
            print(f"Switched to {new_model_type} model, loading weights from [{ckpt_name}]...")
            model = new_model

        # Download model from HF Hub
        local_path = download_model(ckpt_name)

        # Checkpoints with "eps" in the name are epsilon-parameterized;
        # everything else uses v-prediction.
        model.parameterization = "eps" if ckpt_name.find("eps") > -1 else "v"
        model.init_from_ckpt(local_path, logging=True)
        model.switch_to_fp16()

        model_type = new_model_type
        current_checkpoint = ckpt_name
        print(f"Loaded model from [{ckpt_name}], model_type [{model_type}].")
        gr.Info("Loaded model successfully.")

    except Exception as e:
        print(f"Error type: {e}")
        # Fix: traceback.print_exc() prints the traceback itself and returns
        # None — wrapping it in print() emitted a spurious "None" line.
        traceback.print_exc()
        gr.Info("Failed in loading model.")
139
+
140
+
141
def get_last_seed():
    """Return the seed used by the most recent generation, or -1 if none yet.

    Uses an explicit None check: the previous ``global_seed or -1`` collapsed
    a perfectly valid seed of 0 (falsy) to -1, breaking "Reuse Seed" for that
    value.
    """
    return global_seed if global_seed is not None else -1
143
+
144
+
145
def reset_random_seed():
    """Return the sentinel seed (-1) that asks inference to pick a random one."""
    return -1
147
+
148
+
149
def visualize(reference, text, *args):
    """Render attention heatmaps for `reference` under the prompts parsed from `text`.

    Thin adapter between the Gradio button and visualize_heatmaps; `args`
    carries the control prompt and the four thresholds.
    """
    manipulation_params = parse_prompts(text)
    return visualize_heatmaps(model, reference, manipulation_params, *args)
151
+
152
+
153
def set_cas_scales(accurate, cas_args):
    """Build the reference-strength configuration passed to model.generate.

    When `accurate` is False, the first four entries of `cas_args`
    (encoder, middle, low, overall strength) are folded into per-level
    scales; otherwise the remaining entries are forwarded verbatim as an
    explicit scale list.
    """
    enc, mid, low, overall = cas_args[:4]
    if accurate:
        return {"level_control": False, "scales": list(cas_args[4:])}
    return {
        "level_control": True,
        "scales": {
            "encoder": enc * overall,
            "middle": mid * overall,
            "low": low * overall,
        },
    }
170
+
171
+
172
@torch.no_grad()
def inference(
    style_enhance, bg_enhance, fg_enhance, fg_disentangle_scale,
    bs, input_s, input_r, input_bg, mask_ts, mask_ss, gs_r, gs_s, ctl_scale,
    ctl_scale_1, ctl_scale_2, ctl_scale_3, ctl_scale_4,
    fg_strength, bg_strength, merge_scale, mask_scale, height, width, seed, low_vram, step,
    injection, autofit_size, remove_fg, rmbg, latent_inpaint, infid_x, infid_r, injstep, crop, pad_scale,
    start_step, end_step, no_start_step, no_end_step, return_inter, sampler, scheduler, preprocess,
    deterministic, text, target, anchor, control, target_scale, ts0, ts1, ts2, ts3, enhance, accurate,
    *args
):
    """Run one colorization: preprocess inputs, optionally extract masks,
    sample with model.generate, and postprocess into a gallery list.

    Parameter order must match the `ips` list built in app.py. `args`
    carries the CAS scale states consumed by set_cas_scales.
    """
    global global_seed, line_extractor, mask_extractor
    # -1 means "pick a random seed"; remember it so "Reuse Seed" works.
    global_seed = seed if seed > -1 else random.randint(0, MAXM_INT32)
    torch.manual_seed(global_seed)

    # Auto-fit size based on sketch dimensions
    if autofit_size and exists(input_s):
        sketch_w, sketch_h = input_s.size
        aspect_ratio = sketch_w / sketch_h
        target_area = 1024 * 1024
        new_h = int((target_area / aspect_ratio) ** 0.5)
        new_w = int(new_h * aspect_ratio)
        # Round to the nearest multiple of 32 and clamp to [768, 1536].
        height = ((new_h + 16) // 32) * 32
        width = ((new_w + 16) // 32) * 32
        height = max(768, min(1536, height))
        width = max(768, min(1536, width))
        gr.Info(f"Auto-fitted size: {width}x{height}")

    smask, rmask, bgmask = None, None, None
    manipulation_params = parse_prompts(text, target, anchor, control, target_scale, ts0, ts1, ts2, ts3, enhance)
    inputs = preprocessing_inputs(
        sketch = input_s,
        reference = input_r,
        background = input_bg,
        preprocess = preprocess,
        hook = injection,
        resolution = (height, width),
        extractor = line_extractor,
        pad_scale = pad_scale,
    )
    sketch, reference, background, original_shape, inject_xr, inject_xs, white_sketch = inputs

    cond = {"reference": reference, "sketch": sketch, "background": background}
    mask_guided = bg_enhance or fg_enhance

    # Mask-guided path: extract sketch/reference/background masks on GPU,
    # then park the extractors back on CPU.
    if exists(white_sketch) and exists(reference) and mask_guided:
        mask_extractor.cuda()
        smask_extractor.cuda()
        smask = smask_extractor.proceed(
            x=white_sketch, pil_x=input_s, th=height, tw=width, threshold=mask_ss, crop=False
        )

        if exists(background) and remove_fg:
            # Mask out the foreground of the background input.
            bgmask = mask_extractor.proceed(x=background, pil_x=input_bg, threshold=mask_ts, dilate=True)
            filtered_background = torch.where(bgmask < mask_ts, background, torch.ones_like(background))
            cond.update({"background": filtered_background, "rmask": bgmask})
        else:
            rmask = mask_extractor.proceed(x=reference, pil_x=input_r, threshold=mask_ts, dilate=True)
            cond.update({"rmask": rmask})
            # Binarized copy kept for postprocess visualization.
            rmask = torch.where(rmask > 0.5, torch.ones_like(rmask), torch.zeros_like(rmask))
        cond.update({"smask": smask})
        smask_extractor.cpu()
        mask_extractor.cpu()

    scale_strength = set_cas_scales(accurate, args)
    ctl_scales = [ctl_scale_1, ctl_scale_2, ctl_scale_3, ctl_scale_4]
    ctl_scales = [t * ctl_scale for t in ctl_scales]

    results = model.generate(
        # Colorization mode
        style_enhance = style_enhance,
        bg_enhance = bg_enhance,
        fg_enhance = fg_enhance,
        fg_disentangle_scale = fg_disentangle_scale,
        latent_inpaint = latent_inpaint,

        # Conditional inputs
        cond = cond,
        ctl_scale = ctl_scales,
        merge_scale = merge_scale,
        mask_scale = mask_scale,
        mask_thresh = mask_ts,
        mask_thresh_sketch = mask_ss,

        # Sampling settings
        bs = bs,
        gs = [gs_r, gs_s],
        sampler = sampler,
        scheduler = scheduler,
        start_step = start_step,
        end_step = end_step,
        no_start_step = no_start_step,
        no_end_step = no_end_step,
        strength = scale_strength,
        fg_strength = fg_strength,
        bg_strength = bg_strength,
        seed = global_seed,
        deterministic = deterministic,
        height = height,
        width = width,
        step = step,

        # Injection settings
        injection = injection,
        injection_cfg = infid_r,
        injection_control = infid_x,
        injection_start_step = injstep,
        hook_xr = inject_xr,
        hook_xs = inject_xs,

        # Additional settings
        low_vram = low_vram,
        return_intermediate = return_inter,
        manipulation_params = manipulation_params,
    )

    # Optionally blank out everything outside the sketch's own mask.
    # NOTE(review): this branch moves mask_extractor to CUDA but calls
    # smask_extractor.proceed — confirm smask_extractor's device here.
    if rmbg:
        mask_extractor.cuda()
        mask = smask_extractor.proceed(x=-sketch, threshold=mask_ss).repeat(results.shape[0], 1, 1, 1)
        results = torch.where(mask >= mask_ss, results, torch.ones_like(results))
        mask_extractor.cpu()

    results = postprocess(results, sketch, reference, background, crop, original_shape,
                          mask_guided, smask, rmask, bgmask, mask_ts, mask_ss)
    torch.cuda.empty_cache()
    gr.Info("Generation completed.")
    return results
backend/functool.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import PIL.Image as Image
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torchvision.transforms as transforms
8
+
9
+ from functools import partial
10
+
11
# Upper bound (px) on either side of a reference fed to the heatmap visualizer.
maxium_resolution = 4096
# Attention tokens per image side (square root of the 256-token grid).
token_length = int(256 ** 0.5)

def exists(v):
    """Return True when `v` is not None."""
    return v is not None

# Factory for bicubic, antialiased resize transforms: resize((h, w)) -> transform.
resize = partial(transforms.Resize, interpolation=transforms.InterpolationMode.BICUBIC, antialias=True)
18
+
19
def resize_image(img, new_size, w, h):
    """Resize `img` so its longer side (per the given w, h) becomes `new_size`,
    preserving aspect ratio."""
    if w > h:
        target = (int(h / w * new_size), new_size)
    else:
        target = (new_size, int(w / h * new_size))
    return resize(target)(img)
25
+
26
def pad_image(image: torch.Tensor, h, w):
    """Center `image` on an (h, w) canvas filled with -1 (black in [-1, 1] space).

    Returns the padded tensor and (left, top, width, height) of the pasted
    region, as consumed by crop_image_from_square.
    """
    b, c, ih, iw = image.shape
    canvas = torch.full([b, c, h, w], -1.0, device=image.device)
    left, top = (w - iw) // 2, (h - ih) // 2
    canvas[..., top:top + ih, left:left + iw] = image
    return canvas, (left, top, iw, ih)
34
+
35
+
36
def pad_image_with_margin(image: Image, scale):
    """Widen `image` to scale * width by centering it on a white canvas of the
    same height."""
    w, h = image.size
    padded_w = int(w * scale)
    canvas = Image.new('RGB', (padded_w, h), (255, 255, 255))
    canvas.paste(image, ((padded_w - w) // 2, 0))
    return canvas
42
+
43
+
44
def crop_image_from_square(square_image, original_dim):
    """Undo pad_image: crop the centered region described by `original_dim`
    ((left, top, width, height)) out of the padded image."""
    left, top, width, height = original_dim
    box = (left, top, left + width, top + height)
    return square_image.crop(box)
47
+
48
+
49
def to_tensor(x, inverse=False):
    """Convert a PIL image / array to a batched CUDA tensor normalized to
    [-1, 1]; negate the result when `inverse` is set (sketch inversion)."""
    tensor = transforms.ToTensor()(x).unsqueeze(0)
    tensor = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(tensor).cuda()
    return -tensor if inverse else tensor
53
+
54
def to_numpy(x, denormalize=True):
    """Tensor -> uint8 ndarray.

    With `denormalize`, maps a batched NCHW tensor in [-1, 1] to NHWC in
    [0, 255]; otherwise treats `x` as a [0, 1] mask and returns its first
    (batch, channel) plane scaled to [0, 255].
    """
    if denormalize:
        scaled = (x.clamp(-1, 1) + 1.) * 127.5
        return scaled.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
    return (x.clamp(0, 1) * 255)[0][0].cpu().numpy().astype(np.uint8)
59
+
60
def lineart_standard(x: Image.Image):
    """WebUI-style line extraction: difference-of-Gaussian edge response,
    normalized by its median and returned as a [-1, 1] CUDA tensor."""
    x = np.array(x).astype(np.float32)
    g = cv2.GaussianBlur(x, (0, 0), 6.0)
    # Darkest-channel response of (blurred - original): lines are darker
    # than their blurred surroundings.
    intensity = np.min(g - x, axis=2).clip(0, 255)
    # Normalize by the median of the meaningful responses (floor of 16
    # guards against division blow-up on near-blank inputs).
    intensity /= max(16, np.median(intensity[intensity > 8]))
    intensity *= 127
    intensity = np.repeat(np.expand_dims(intensity, 2), 3, axis=2)
    result = to_tensor(intensity.clip(0, 255).astype(np.uint8))
    return result
69
+
70
def preprocess_sketch(sketch, resolution, preprocess="none", extractor=None, new=False):
    """Convert a PIL sketch into a padded tensor at `resolution` (th, tw).

    `preprocess` selects the conversion: raw, inverted, WebUI lineart, or a
    neural extractor. Returns (sketch tensor, pasted-region dims from
    pad_image, white-on-black variant used for mask extraction).
    """
    w, h = sketch.size
    th, tw = resolution
    # Uniform scale so the sketch fits inside (th, tw).
    r = min(th/h, tw/w)

    if preprocess == "none":
        sketch = to_tensor(sketch)
    elif preprocess == "invert":
        sketch = to_tensor(sketch, inverse=True)
    elif preprocess == "invert-webui":
        sketch = lineart_standard(sketch)
    else:
        # Neural extractor runs at a fixed 768x768, then is broadcast to 3 channels.
        sketch = extractor.proceed(resize((768, 768))(sketch)).repeat(1, 3, 1, 1)

    sketch, original_shape = pad_image(resize((int(h*r), int(w*r)))(sketch), th, tw)
    if new:
        # "new" models expect sketches in [0, 1] rather than [-1, 1].
        sketch = ((sketch + 1) / 2.).clamp(0, 1)
        white_sketch = 1 - sketch
    else:
        white_sketch = -sketch
    return sketch, original_shape, white_sketch
91
+
92
+
93
@torch.no_grad()
def preprocessing_inputs(
    sketch: Image.Image,
    reference: Image.Image,
    background: Image.Image,
    preprocess: str,
    hook: bool,
    resolution: tuple[int, int],
    extractor: nn.Module,
    pad_scale: float = 1.,
    new = False
):
    """Convert the three optional PIL inputs into CUDA tensors for generation.

    Returns (sketch, reference, background, original_shape, inject_xr,
    inject_xs, white_sketch); inject_xr is only set when attention
    injection (`hook`) is requested.
    """
    extractor = extractor.cuda()
    h, w = resolution
    if exists(sketch):
        sketch, original_shape, white_sketch = preprocess_sketch(sketch, resolution, preprocess, extractor, new)
    else:
        # No sketch: blank canvas (0 for "new"-style models, -1 otherwise).
        sketch = torch.zeros([1, 3, h, w], device="cuda") if new else -torch.ones([1, 3, h, w], device="cuda")
        white_sketch = None
        original_shape = (0, 0, h, w)

    inject_xs = None
    if hook:
        # Attention injection needs the reference resized to the target canvas.
        assert exists(reference) and exists(extractor)
        maxm = max(h, w)  # NOTE(review): only used by the commented-out inject_xs path
        # inject_xs = resize((h, w))(extractor.proceed(resize((maxm, maxm))(reference)).repeat(1, 3, 1, 1))
        inject_xr = to_tensor(resize((h, w))(reference))
    else:
        inject_xr = None
    extractor = extractor.cpu()

    if exists(reference):
        if pad_scale > 1.:
            reference = pad_image_with_margin(reference, pad_scale)
        reference = to_tensor(reference)

    if exists(background):
        if pad_scale > 1.:
            background = pad_image_with_margin(background, pad_scale)
        background = to_tensor(background)

    return sketch, reference, background, original_shape, inject_xr, inject_xs, white_sketch
135
+
136
def postprocess(results, sketch, reference, background, crop, original_shape,
                mask_guided, smask, rmask, bgmask, mask_ts, mask_ss, new=False):
    """Assemble the gallery shown in the UI.

    The generated images come first (optionally cropped back to the
    original sketch region), followed by the inputs and, in mask-guided
    mode, the thresholded masks and masked composites.
    """
    results = to_numpy(results)
    # Sketch was normalized to [-1, 1] unless `new`; denormalize accordingly.
    sketch = to_numpy(sketch, not new)[0]

    results_list = []
    for result in results:
        result = Image.fromarray(result)
        if crop:
            result = crop_image_from_square(result, original_shape)
        results_list.append(result)

    results_list.append(sketch)

    if exists(reference):
        reference = to_numpy(reference)[0]
        results_list.append(reference)
        # if vis_crossattn:
        #     results_list += visualize_attention_map(reference, results_list[0], vh, vw)

    if exists(background):
        background = to_numpy(background)[0]
        results_list.append(background)

        if exists(bgmask):
            # Show background split by its mask: kept region and removed region.
            background = Image.fromarray(background)
            results_list.append(Image.composite(
                background,
                Image.new("RGB", background.size, (255, 255, 255)),
                Image.fromarray(to_numpy(bgmask, denormalize=False), mode="L")
            ))
            results_list.append(Image.composite(
                Image.new("RGB", background.size, (255, 255, 255)),
                background,
                Image.fromarray(to_numpy(bgmask, denormalize=False), mode="L")
            ))

    if mask_guided:
        # Zero out sub-threshold sketch-mask values before display.
        smask[smask < mask_ss] = 0
        results_list.append(Image.fromarray(to_numpy(smask, denormalize=False), mode="L"))

        if exists(rmask):
            # Show the reference split by its mask, like the background above.
            reference = Image.fromarray(reference)
            rmask[rmask < mask_ts] = 0
            results_list.append(Image.fromarray(to_numpy(rmask, denormalize=False), mode="L"))
            results_list.append(Image.composite(
                reference,
                Image.new("RGB", reference.size, (255, 255, 255)),
                Image.fromarray(to_numpy(rmask, denormalize=False), mode="L")
            ))
            results_list.append(Image.composite(
                Image.new("RGB", reference.size, (255, 255, 255)),
                reference,
                Image.fromarray(to_numpy(rmask, denormalize=False), mode="L")
            ))

    return results_list
193
+
194
+
195
def parse_prompts(
    prompts: str,
    target: str = None,
    anchor: str = None,
    control: str = None,
    target_scale: float = None,
    ts0: float = None,
    ts1: float = None,
    ts2: float = None,
    ts3: float = None,
    enhance: bool = None
):
    """Parse the accumulated prompt text (one "[target] ..." entry per line,
    as produced by apppend_prompt) plus an optional in-flight entry from the
    widgets.

    Returns a dict of parallel lists: targets, anchors, controls,
    target_scales, enhances, thresholds_list.
    """

    targets = []
    anchors = []
    controls = []
    scales = []
    enhances = []
    thresholds_list = []

    # Field separators written by apppend_prompt; each is collapsed to a
    # common delimiter before splitting. (Renamed from `str`, which
    # shadowed the builtin.)
    separators = ["; [anchor] ", "; [control] ", "; [scale]", "; [enhanced]",
                  "; [ts0]", "; [ts1]", "; [ts2]", "; [ts3]"]
    if prompts != "" and prompts is not None:
        for line in prompts.split('\n'):
            line = line.replace("[target] ", "")
            for sep in separators:
                line = line.replace(sep, "||||")

            fields = line.split("||||")
            targets.append(fields[0])
            anchors.append(fields[1])
            controls.append(fields[2])
            scales.append(float(fields[3]))
            # Fix: bool("False") is True for any non-empty string, so the
            # enhanced flag was always parsed as True. Compare the text.
            enhances.append(fields[4].strip() == "True")
            thresholds_list.append([float(fields[5]), float(fields[6]),
                                    float(fields[7]), float(fields[8])])

    # Append the entry currently typed into the widgets, if any.
    if target is not None and target != "":
        targets.append(target)
        anchors.append(anchor)
        controls.append(control)
        scales.append(target_scale)
        enhances.append(enhance)
        thresholds_list.append([ts0, ts1, ts2, ts3])

    return {
        "targets": targets,
        "anchors": anchors,
        "controls": controls,
        "target_scales": scales,
        "enhances": enhances,
        "thresholds_list": thresholds_list
    }
247
+
248
+
249
# Imported here rather than at the top of the file — presumably to avoid a
# circular import with refnet; TODO confirm before moving it.
from refnet.sampling.manipulation import get_heatmaps
def visualize_heatmaps(model, reference, manipulation_params, control, ts0, ts1, ts2, ts3):
    """Overlay the manipulation attention heatmap on `reference`, with a
    token_length x token_length grid drawn on top. Returns [ndarray] or []
    when no reference is given."""
    if reference is None:
        return []

    # Downscale oversized references, keeping aspect ratio.
    size = reference.size
    if size[0] > maxium_resolution or size[1] > maxium_resolution:
        if size[0] > size[1]:
            size = (maxium_resolution, int(float(maxium_resolution) / size[0] * size[1]))
        else:
            size = (int(float(maxium_resolution) / size[1] * size[0]), maxium_resolution)
        reference = reference.resize(size, Image.BICUBIC)

    reference = np.array(reference)
    scale_maps = get_heatmaps(model, to_tensor(reference), size[1], size[0],
                              control, ts0, ts1, ts2, ts3, **manipulation_params)

    # Sum the four threshold-level maps and render as a JET overlay.
    scale_map = scale_maps[0] + scale_maps[1] + scale_maps[2] + scale_maps[3]
    heatmap = cv2.cvtColorcv2 = cv2.cvtColor(cv2.applyColorMap(scale_map, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB) if False else cv2.cvtColor(cv2.applyColorMap(scale_map, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
    result = cv2.addWeighted(reference, 0.3, heatmap, 0.7, 0)
    # Draw the 16x16 token grid (token_length == 16).
    hu = size[1] // token_length
    wu = size[0] // token_length
    for i in range(16):
        result[i * hu, :] = (0, 0, 0)
    for i in range(16):
        result[:, i * wu] = (0, 0, 0)

    return [result]
backend/style.css ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ :root {
2
+ --primary-color: #9b59b6;
3
+ --primary-light: #d6c6e1;
4
+ --secondary-color: #2ecc71;
5
+ --text-color: #333333;
6
+ --background-color: #f9f9f9;
7
+ --card-bg: #ffffff;
8
+ --border-radius: 10px;
9
+ --shadow-sm: 0 2px 5px rgba(0, 0, 0, 0.05);
10
+ --shadow-md: 0 5px 15px rgba(0, 0, 0, 0.07);
11
+ --shadow-lg: 0 10px 25px rgba(0, 0, 0, 0.1);
12
+ --gradient: linear-gradient(135deg, var(--primary-color), var(--secondary-color));
13
+ --input-border: #e0e0e0;
14
+ --input-bg: #ffffff;
15
+ --font-weight-normal: 500;
16
+ --font-weight-bold: 700;
17
+ }
18
+
19
+ /* Base styles */
20
+ body, html {
21
+ margin: 0;
22
+ padding: 0;
23
+ font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
24
+ font-weight: var(--font-weight-normal);
25
+ background-color: var(--background-color);
26
+ color: var(--text-color);
27
+ width: 100vw;
28
+ overflow-x: hidden;
29
+ }
30
+
31
+ * {
32
+ box-sizing: border-box;
33
+ }
34
+
35
+ /* Force full width layout */
36
+ #main-interface,
37
+ .gradio-app,
38
+ .gradio-container {
39
+ width: 100vw !important;
40
+ max-width: 100vw !important;
41
+ margin: 0 !important;
42
+ padding: 0 !important;
43
+ box-shadow: none !important;
44
+ border: none !important;
45
+ overflow-x: hidden !important;
46
+ }
47
+
48
+ /* Header styling */
49
+ #header-row {
50
+ background: white;
51
+ padding: 15px 20px;
52
+ margin-bottom: 20px;
53
+ box-shadow: var(--shadow-sm);
54
+ border-bottom: 1px solid rgba(0,0,0,0.05);
55
+ }
56
+
57
+ .header-container {
58
+ width: 100%;
59
+ display: flex;
60
+ flex-direction: column;
61
+ align-items: center;
62
+ padding: 10px 0;
63
+ }
64
+
65
+ .app-header {
66
+ display: flex;
67
+ align-items: center;
68
+ gap: 12px;
69
+ margin-bottom: 15px;
70
+ }
71
+
72
+ .app-header .emoji {
73
+ font-size: 36px;
74
+ }
75
+
76
+ /* Fix for Colorize Diffusion title visibility */
77
+ .gradio-markdown h1,
78
+ .gradio-markdown h2,
79
+ #header-row h1,
80
+ #header-row h2,
81
+ .title-text,
82
+ .app-header .title-text {
83
+ display: inline-block !important;
84
+ visibility: visible !important;
85
+ opacity: 1 !important;
86
+ position: relative !important;
87
+ color: var(--primary-color) !important;
88
+ font-size: 32px !important;
89
+ font-weight: 800 !important;
90
+ }
91
+
92
+ /* Badge links under the header */
93
+ .paper-links-icons {
94
+ display: flex;
95
+ flex-wrap: wrap;
96
+ justify-content: center;
97
+ gap: 8px;
98
+ margin-top: 5px;
99
+ }
100
+
101
+ .paper-links-icons a {
102
+ transition: transform 0.2s ease;
103
+ opacity: 0.9;
104
+ }
105
+
106
+ .paper-links-icons a:hover {
107
+ transform: translateY(-3px);
108
+ opacity: 1;
109
+ }
110
+
111
+ /* Content layout */
112
+ #content-row {
113
+ padding: 0 20px 20px 20px;
114
+ max-width: 100%;
115
+ margin: 0 auto;
116
+ }
117
+
118
+ /* Apply bold font to all text elements for better readability */
119
+ p, span, label, button, input, textarea, select, .gradio-button, .gradio-checkbox, .gradio-dropdown, .gradio-textbox {
120
+ font-weight: var(--font-weight-normal);
121
+ }
122
+
123
+ /* Make headings bolder */
124
+ h1, h2, h3, h4, h5, h6 {
125
+ font-weight: var(--font-weight-bold);
126
+ }
127
+
128
+ /* Improved font styling for Gradio UI elements */
129
+ .gradio-container,
130
+ .gradio-container *,
131
+ .gradio-app,
132
+ .gradio-app * {
133
+ font-family: 'Roboto', 'Segoe UI', system-ui, -apple-system, sans-serif !important;
134
+ font-weight: 500 !important;
135
+ }
136
+
137
+ /* Style for labels and slider labels */
138
+ .gradio-container label,
139
+ .gradio-slider label,
140
+ .gradio-checkbox label,
141
+ .gradio-radio label,
142
+ .gradio-dropdown label,
143
+ .gradio-textbox label,
144
+ .gradio-number label,
145
+ .gradio-button,
146
+ .gradio-checkbox span,
147
+ .gradio-radio span {
148
+ font-family: 'Roboto', 'Segoe UI', system-ui, -apple-system, sans-serif !important;
149
+ font-weight: 600 !important;
150
+ letter-spacing: 0.01em;
151
+ }
152
+
153
+ /* Style for buttons */
154
+ button,
155
+ .gradio-button {
156
+ font-family: 'Roboto', 'Segoe UI', system-ui, -apple-system, sans-serif !important;
157
+ font-weight: 600 !important;
158
+ }
159
+
160
+ /* Style for input values */
161
+ input,
162
+ textarea,
163
+ select,
164
+ .gradio-textbox textarea,
165
+ .gradio-number input {
166
+ font-family: 'Roboto', 'Segoe UI', system-ui, -apple-system, sans-serif !important;
167
+ font-weight: 500 !important;
168
+ }
169
+
170
+ /* Better styling for drop areas */
171
+ .upload-box,
172
+ [data-testid="image"] {
173
+ font-family: 'Roboto', 'Segoe UI', system-ui, -apple-system, sans-serif !important;
174
+ font-weight: 500 !important;
175
+ }
176
+
177
+ /* Additional styling for values in sliders and numbers */
178
+ .wrap .wrap .wrap span {
179
+ font-family: 'Roboto', 'Segoe UI', system-ui, -apple-system, sans-serif !important;
180
+ font-weight: 600 !important;
181
+ }
configs/inference/sdxl.yaml ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-6
3
+ target: refnet.models.colorizerXL.InferenceWrapper
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.0120
7
+ timesteps: 1000
8
+ image_size: 128
9
+ channels: 4
10
+ scale_factor: 0.13025
11
+ logits_embed: false
12
+
13
+ unet_config:
14
+ target: refnet.modules.unet.DualCondUNetXL
15
+ params:
16
+ use_checkpoint: True
17
+ in_channels: 4
18
+ out_channels: 4
19
+ model_channels: 320
20
+ adm_in_channels: 512
21
+ # adm_in_channels: 2816
22
+ num_classes: sequential
23
+ attention_resolutions: [4, 2]
24
+ num_res_blocks: 2
25
+ channel_mult: [1, 2, 4]
26
+ num_head_channels: 64
27
+ use_spatial_transformer: true
28
+ use_linear_in_transformer: true
29
+ transformer_depth: [1, 2, 10]
30
+ context_dim: 2048
31
+
32
+ first_stage_config:
33
+ target: ldm.models.autoencoder.AutoencoderKL
34
+ params:
35
+ embed_dim: 4
36
+ ddconfig:
37
+ double_z: true
38
+ z_channels: 4
39
+ resolution: 512
40
+ in_channels: 3
41
+ out_ch: 3
42
+ ch: 128
43
+ ch_mult: [1, 2, 4, 4]
44
+ num_res_blocks: 2
45
+ attn_resolutions: []
46
+ dropout: 0.0
47
+
48
+ cond_stage_config:
49
+ target: refnet.modules.embedder.HFCLIPVisionModel
50
+ # target: refnet.modules.embedder.FrozenOpenCLIPImageEmbedder
51
+ params:
52
+ arch: ViT-bigG-14
53
+
54
+ control_encoder_config:
55
+ # target: refnet.modules.encoder.MultiEncoder
56
+ target: refnet.modules.encoder.MultiScaleAttentionEncoder
57
+ params:
58
+ in_ch: 3
59
+ model_channels: 320
60
+ ch_mults: [ 1, 2, 4 ]
61
+
62
+ img_embedder_config:
63
+ target: refnet.modules.embedder.WDv14SwinTransformerV2
64
+
65
+ scalar_embedder_config:
66
+ target: refnet.modules.embedder.TimestepEmbedding
67
+ params:
68
+ embed_dim: 256
69
+
70
+ proj_config:
71
+ target: refnet.modules.proj.ClusterConcat
72
+ # target: refnet.modules.proj.RecoveryClusterConcat
73
+ params:
74
+ input_dim: 1280
75
+ c_dim: 1024
76
+ output_dim: 2048
77
+ token_length: 196
78
+ dim_head: 128
79
+ # proj_config:
80
+ # target: refnet.modules.proj.LogitClusterConcat
81
+ # params:
82
+ # input_dim: 1280
83
+ # c_dim: 1024
84
+ # output_dim: 2048
85
+ # token_length: 196
86
+ # dim_head: 128
87
+ # mlp_in_dim: 9083
88
+ # mlp_ckpt_path: pretrained_models/proj.safetensors
configs/inference/xlv2.yaml ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-6
3
+ target: refnet.models.v2-colorizerXL.InferenceWrapperXL
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.0120
7
+ timesteps: 1000
8
+ image_size: 128
9
+ channels: 4
10
+ scale_factor: 0.13025
11
+ controller: true
12
+
13
+ unet_config:
14
+ target: refnet.modules.unet.DualCondUNetXL
15
+ params:
16
+ use_checkpoint: True
17
+ in_channels: 4
18
+ in_channels_fg: 4
19
+ out_channels: 4
20
+ model_channels: 320
21
+ adm_in_channels: 512
22
+ num_classes: sequential
23
+ attention_resolutions: [4, 2]
24
+ num_res_blocks: 2
25
+ channel_mult: [1, 2, 4]
26
+ num_head_channels: 64
27
+ use_spatial_transformer: true
28
+ use_linear_in_transformer: true
29
+ transformer_depth: [1, 2, 10]
30
+ context_dim: 2048
31
+ map_module: false
32
+ warp_module: false
33
+ style_modulation: false
34
+
35
+ bg_encoder_config:
36
+ target: refnet.modules.unet.ReferenceNet
37
+ params:
38
+ use_checkpoint: True
39
+ in_channels: 6
40
+ model_channels: 320
41
+ adm_in_channels: 1024
42
+ num_classes: sequential
43
+ attention_resolutions: [ 4, 2 ]
44
+ num_res_blocks: 2
45
+ channel_mult: [ 1, 2, 4 ]
46
+ num_head_channels: 64
47
+ use_spatial_transformer: true
48
+ use_linear_in_transformer: true
49
+ disable_cross_attentions: true
50
+ context_dim: 2048
51
+ transformer_depth: [ 1, 2, 10 ]
52
+
53
+
54
+ first_stage_config:
55
+ target: ldm.models.autoencoder.AutoencoderKL
56
+ params:
57
+ embed_dim: 4
58
+ ddconfig:
59
+ double_z: true
60
+ z_channels: 4
61
+ resolution: 512
62
+ in_channels: 3
63
+ out_ch: 3
64
+ ch: 128
65
+ ch_mult: [1, 2, 4, 4]
66
+ num_res_blocks: 2
67
+ attn_resolutions: []
68
+ dropout: 0.0
69
+
70
+ cond_stage_config:
71
+ target: refnet.modules.embedder.HFCLIPVisionModel
72
+ params:
73
+ arch: ViT-bigG-14
74
+
75
+ img_embedder_config:
76
+ target: refnet.modules.embedder.WDv14SwinTransformerV2
77
+
78
+ control_encoder_config:
79
+ target: refnet.modules.encoder.MultiScaleAttentionEncoder
80
+ params:
81
+ in_ch: 3
82
+ model_channels: 320
83
+ ch_mults: [1, 2, 4]
84
+
85
+ proj_config:
86
+ target: refnet.modules.proj.ClusterConcat
87
+ # target: refnet.modules.proj.RecoveryClusterConcat
88
+ params:
89
+ input_dim: 1280
90
+ c_dim: 1024
91
+ output_dim: 2048
92
+ token_length: 196
93
+ dim_head: 128
94
+
95
+ scalar_embedder_config:
96
+ target: refnet.modules.embedder.TimestepEmbedding
97
+ params:
98
+ embed_dim: 256
99
+
100
+ lora_config:
101
+ lora_params: [
102
+ {
103
+ label: background,
104
+ root_module: model.diffusion_model,
105
+ target_keys: [ attn2.to_q, attn2.to_k, attn2.to_v ],
106
+ r: 4,
107
+ }
108
+ ]
configs/scheduler_cfgs/ddim.yaml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ beta_start: 0.00085
2
+ beta_end: 0.012
3
+ beta_schedule: "scaled_linear"
4
+ clip_sample: false
5
+ steps_offset: 1
6
+
7
+ ### Zero-SNR params
8
+ #rescale_betas_zero_snr: True
9
+ #timestep_spacing: "trailing"
10
+ timestep_spacing: "leading"
configs/scheduler_cfgs/dpm.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ beta_start: 0.00085
2
+ beta_end: 0.012
3
+ beta_schedule: "scaled_linear"
4
+ steps_offset: 1
5
+
6
+ ### Zero-SNR params
7
+ #rescale_betas_zero_snr: True
8
+ timestep_spacing: "leading"
configs/scheduler_cfgs/dpm_sde.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ beta_start: 0.00085
2
+ beta_end: 0.012
3
+ beta_schedule: "scaled_linear"
4
+ steps_offset: 1
5
+
6
+ ### Zero-SNR params
7
+ #rescale_betas_zero_snr: True
8
+ timestep_spacing: "leading"
9
+ algorithm_type: sde-dpmsolver++
configs/scheduler_cfgs/lms.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ beta_start: 0.00085
2
+ beta_end: 0.012
3
+ beta_schedule: "scaled_linear"
4
+ #clip_sample: false
5
+ steps_offset: 1
6
+
7
+ ### Zero-SNR params
8
+ #rescale_betas_zero_snr: True
9
+ timestep_spacing: "leading"
configs/scheduler_cfgs/pndm.yaml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ beta_start: 0.00085
2
+ beta_end: 0.012
3
+ beta_schedule: "scaled_linear"
4
+ #clip_sample: false
5
+ steps_offset: 1
6
+
7
+ ### Zero-SNR params
8
+ #rescale_betas_zero_snr: True
9
+ #timestep_spacing: "trailing"
10
+ timestep_spacing: "leading"
k_diffusion/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from .sampling import *
2
+
3
+
4
def create_noise_sampler(x, sigmas, seed):
    """Build a BrownianTreeNoiseSampler for DPM++ SDE sampling.

    A seeded Brownian tree makes results deterministic regardless of the
    batch size used at sampling time.
    """
    from k_diffusion.sampling import BrownianTreeNoiseSampler
    positive_sigmas = sigmas[sigmas > 0]
    return BrownianTreeNoiseSampler(x, positive_sigmas.min(), sigmas.max(), seed=seed)
k_diffusion/external.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ from . import sampling, utils
7
+
8
+
9
class VDenoiser(nn.Module):
    """A v-diffusion-pytorch model wrapper for k-diffusion.

    Adapts a v-prediction model to the Karras denoiser interface:
    input/output scalings plus the sigma <-> t mapping.
    """

    def __init__(self, inner_model):
        super().__init__()
        self.inner_model = inner_model
        self.sigma_data = 1.

    def get_scalings(self, sigma):
        """Return (c_skip, c_out, c_in) for noise level *sigma*."""
        variance = sigma ** 2 + self.sigma_data ** 2
        c_skip = self.sigma_data ** 2 / variance
        c_out = -sigma * self.sigma_data / variance ** 0.5
        c_in = 1 / variance ** 0.5
        return c_skip, c_out, c_in

    def sigma_to_t(self, sigma):
        """Map a noise level to the model's continuous timestep in [0, 1]."""
        return sigma.atan() / math.pi * 2

    def t_to_sigma(self, t):
        """Inverse of :meth:`sigma_to_t`."""
        return (t * math.pi / 2).tan()

    def loss(self, input, noise, sigma, **kwargs):
        """Per-sample v-prediction MSE loss at noise level *sigma*."""
        scalings = self.get_scalings(sigma)
        c_skip, c_out, c_in = (utils.append_dims(s, input.ndim) for s in scalings)
        noised_input = input + noise * utils.append_dims(sigma, input.ndim)
        model_output = self.inner_model(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
        target = (input - c_skip * noised_input) / c_out
        return (model_output - target).pow(2).flatten(1).mean(1)

    def forward(self, input, sigma, **kwargs):
        """Denoise *input* at noise level *sigma*."""
        scalings = self.get_scalings(sigma)
        c_skip, c_out, c_in = (utils.append_dims(s, input.ndim) for s in scalings)
        return self.inner_model(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip
39
+
40
+
41
class DiscreteSchedule(nn.Module):
    """A mapping between continuous noise levels (sigmas) and a list of discrete
    noise levels, with interpolation in log-sigma space."""

    def __init__(self, sigmas, quantize):
        super().__init__()
        self.register_buffer('sigmas', sigmas)
        self.register_buffer('log_sigmas', sigmas.log())
        self.quantize = quantize

    @property
    def sigma_min(self):
        """Smallest discrete noise level."""
        return self.sigmas[0]

    @property
    def sigma_max(self):
        """Largest discrete noise level."""
        return self.sigmas[-1]

    def get_sigmas(self, n=None):
        """Return a descending schedule of n sigmas (all of them if n is None),
        with a trailing zero appended."""
        if n is None:
            return sampling.append_zero(self.sigmas.flip(0))
        steps = torch.linspace(len(self.sigmas) - 1, 0, n, device=self.sigmas.device)
        return sampling.append_zero(self.t_to_sigma(steps))

    def sigma_to_t(self, sigma, quantize=None):
        """Map sigma to a (fractional or quantized) discrete timestep index."""
        if quantize is None:
            quantize = self.quantize
        log_sigma = sigma.log()
        dists = log_sigma - self.log_sigmas[:, None]
        if quantize:
            # Snap to the nearest discrete level in log space.
            return dists.abs().argmin(dim=0).view(sigma.shape)
        # Bracketing indices, then linear interpolation in log-sigma space.
        lo = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
        hi = lo + 1
        log_lo, log_hi = self.log_sigmas[lo], self.log_sigmas[hi]
        frac = ((log_lo - log_sigma) / (log_lo - log_hi)).clamp(0, 1)
        t = (1 - frac) * lo + frac * hi
        return t.view(sigma.shape)

    def t_to_sigma(self, t):
        """Map a (possibly fractional) timestep index back to sigma."""
        t = t.float()
        lo, hi = t.floor().long(), t.ceil().long()
        frac = t.frac()
        log_sigma = (1 - frac) * self.log_sigmas[lo] + frac * self.log_sigmas[hi]
        return log_sigma.exp()
85
+
86
+
87
class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
    """A wrapper for discrete schedule DDPM models that output eps (the
    predicted noise)."""

    def __init__(self, model, alphas_cumprod, quantize):
        # sigma_t = sqrt((1 - abar_t) / abar_t) for a VP schedule.
        super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
        self.inner_model = model
        self.sigma_data = 1.

    def get_scalings(self, sigma):
        """Return (c_out, c_in) for eps-prediction at noise level *sigma*."""
        c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
        return -sigma, c_in

    def get_eps(self, *args, **kwargs):
        """Forward to the wrapped model; subclasses adapt the call signature."""
        return self.inner_model(*args, **kwargs)

    def loss(self, input, noise, sigma, **kwargs):
        """Per-sample eps-prediction MSE loss at noise level *sigma*."""
        c_out, c_in = (utils.append_dims(s, input.ndim) for s in self.get_scalings(sigma))
        noised_input = input + noise * utils.append_dims(sigma, input.ndim)
        eps = self.get_eps(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
        return (eps - noise).pow(2).flatten(1).mean(1)

    def forward(self, input, sigma, **kwargs):
        """Denoise *input* at noise level *sigma*."""
        c_out, c_in = (utils.append_dims(s, input.ndim) for s in self.get_scalings(sigma))
        eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
        return input + eps * c_out
114
+
115
+
116
class OpenAIDenoiser(DiscreteEpsDDPMDenoiser):
    """A wrapper for OpenAI diffusion models."""

    def __init__(self, model, diffusion, quantize=False, has_learned_sigmas=True, device='cpu'):
        alphas = torch.tensor(diffusion.alphas_cumprod, device=device, dtype=torch.float32)
        super().__init__(model, alphas, quantize=quantize)
        self.has_learned_sigmas = has_learned_sigmas

    def get_eps(self, *args, **kwargs):
        """Call the wrapped model, keeping only the eps half for
        learned-variance models (which stack [eps, var] on the channel dim)."""
        model_output = self.inner_model(*args, **kwargs)
        return model_output.chunk(2, dim=1)[0] if self.has_learned_sigmas else model_output
129
+
130
+
131
class CompVisDenoiser(DiscreteEpsDDPMDenoiser):
    """A wrapper for CompVis diffusion models."""

    def __init__(self, model, quantize=False, device='cpu'):
        super().__init__(model, model.alphas_cumprod, quantize=quantize)
        # Relocate the schedule buffers to the requested device.
        for name in ('sigmas', 'log_sigmas'):
            setattr(self, name, getattr(self, name).to(device))

    def get_eps(self, *args, **kwargs):
        """Call the CompVis UNet through its apply_model entry point."""
        return self.inner_model.apply_model(*args, **kwargs)
141
+
142
+
143
class DiscreteVDDPMDenoiser(DiscreteSchedule):
    """A wrapper for discrete schedule DDPM models that output v."""

    def __init__(self, model, alphas_cumprod, quantize):
        # sigma_t = sqrt((1 - abar_t) / abar_t) for a VP schedule.
        super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
        self.inner_model = model
        self.sigma_data = 1.

    def get_scalings(self, sigma):
        """Return (c_skip, c_out, c_in) for v-prediction at noise level *sigma*."""
        variance = sigma ** 2 + self.sigma_data ** 2
        c_skip = self.sigma_data ** 2 / variance
        c_out = -sigma * self.sigma_data / variance ** 0.5
        c_in = 1 / variance ** 0.5
        return c_skip, c_out, c_in

    def get_v(self, *args, **kwargs):
        """Forward to the wrapped model; subclasses adapt the call signature."""
        return self.inner_model(*args, **kwargs)

    def loss(self, input, noise, sigma, **kwargs):
        """Per-sample v-prediction MSE loss at noise level *sigma*."""
        c_skip, c_out, c_in = (utils.append_dims(s, input.ndim) for s in self.get_scalings(sigma))
        noised_input = input + noise * utils.append_dims(sigma, input.ndim)
        model_output = self.get_v(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
        target = (input - c_skip * noised_input) / c_out
        return (model_output - target).pow(2).flatten(1).mean(1)

    def forward(self, input, sigma, **kwargs):
        """Denoise *input* at noise level *sigma*."""
        c_skip, c_out, c_in = (utils.append_dims(s, input.ndim) for s in self.get_scalings(sigma))
        return self.get_v(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip
170
+
171
+
172
class CompVisVDenoiser(DiscreteVDDPMDenoiser):
    """A wrapper for CompVis diffusion models that output v."""

    def __init__(self, model, quantize=False, device='cpu'):
        super().__init__(model, model.alphas_cumprod, quantize=quantize)
        # Relocate the schedule buffers to the requested device.
        for name in ('sigmas', 'log_sigmas'):
            setattr(self, name, getattr(self, name).to(device))

    def get_v(self, x, t, cond, **kwargs):
        # NOTE: extra kwargs are accepted but deliberately not forwarded.
        return self.inner_model.apply_model(x, t, cond)
k_diffusion/sampling.py ADDED
@@ -0,0 +1,702 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ from scipy import integrate
4
+ import torch
5
+ from torch import nn
6
+ from torchdiffeq import odeint
7
+ import torchsde
8
+ from tqdm.auto import trange, tqdm
9
+
10
+ from . import utils
11
+
12
+
13
def append_zero(x):
    """Append a single zero to the end of a 1-D tensor of sigmas."""
    zero = x.new_zeros([1])
    return torch.cat([x, zero])
15
+
16
+
17
def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
    """Constructs the noise schedule of Karras et al. (2022).

    Returns n sigmas descending from sigma_max to sigma_min, interpolated in
    sigma^(1/rho) space, with a trailing zero appended (n + 1 values total).
    """
    # Allocate the ramp directly on the target device, matching the other
    # schedule constructors, instead of building on CPU and moving it.
    ramp = torch.linspace(0, 1, n, device=device)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    return append_zero(sigmas).to(device)
24
+
25
+
26
def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'):
    """Constructs an exponential noise schedule."""
    log_sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device)
    return append_zero(log_sigmas.exp())
30
+
31
+
32
def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'):
    """Constructs a noise schedule that is polynomial in log sigma."""
    ramp = torch.linspace(1, 0, n, device=device) ** rho
    log_min, log_max = math.log(sigma_min), math.log(sigma_max)
    sigmas = torch.exp(ramp * (log_max - log_min) + log_min)
    return append_zero(sigmas)
37
+
38
+
39
def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
    """Constructs a continuous VP noise schedule."""
    t = torch.linspace(1, eps_s, n, device=device)
    variance = torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1
    return append_zero(variance.sqrt())
44
+
45
+
46
def to_d(x, sigma, denoised):
    """Converts a denoiser output to a Karras ODE derivative."""
    sigma = utils.append_dims(sigma, x.ndim)
    return (x - denoised) / sigma
49
+
50
+
51
def get_ancestral_step(sigma_from, sigma_to, eta=1.):
    """Calculates the noise level (sigma_down) to step down to and the amount
    of noise to add (sigma_up) when doing an ancestral sampling step."""
    if not eta:
        # Fully deterministic step: no noise is re-injected.
        return sigma_to, 0.
    variance_ratio = (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2
    sigma_up = min(sigma_to, eta * (sigma_to ** 2 * variance_ratio) ** 0.5)
    sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
    return sigma_down, sigma_up
59
+
60
+
61
def default_noise_sampler(x):
    """Return a noise sampler that draws i.i.d. standard normal noise shaped
    like *x*, ignoring the sigma interval arguments."""
    def sampler(sigma, sigma_next):
        return torch.randn_like(x)
    return sampler
63
+
64
+
65
class BatchedBrownianTree:
    """A wrapper around torchsde.BrownianTree that enables batches of entropy."""

    def __init__(self, x, t0, t1, seed=None, **kwargs):
        # Sort the endpoints so t0 < t1 (BrownianTree requires an ordered
        # interval); keep the sign to flip increments back in __call__.
        t0, t1, self.sign = self.sort(t0, t1)
        w0 = kwargs.get('w0', torch.zeros_like(x))
        if seed is None:
            seed = torch.randint(0, 2 ** 63 - 1, []).item()
        self.batched = True
        try:
            # A sequence of seeds means one tree per batch item; len() raises
            # TypeError for a scalar seed, which routes to the except branch.
            assert len(seed) == x.shape[0]
            w0 = w0[0]
        except TypeError:
            seed = [seed]
            self.batched = False
        self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]

    @staticmethod
    def sort(a, b):
        # Returns (min, max, sign) where sign records whether a, b were swapped.
        return (a, b, 1) if a < b else (b, a, -1)

    def __call__(self, t0, t1):
        t0, t1, sign = self.sort(t0, t1)
        # Combine the stored interval orientation with the query orientation.
        w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
        return w if self.batched else w[0]
90
+
91
+
92
class BrownianTreeNoiseSampler:
    """A noise sampler backed by a torchsde.BrownianTree.

    Args:
        x (Tensor): The tensor whose shape, device and dtype to use to generate
            random samples.
        sigma_min (float): The low end of the valid interval.
        sigma_max (float): The high end of the valid interval.
        seed (int or List[int]): The random seed. If a list of seeds is
            supplied instead of a single integer, then the noise sampler will
            use one BrownianTree per batch item, each with its own seed.
        transform (callable): A function that maps sigma to the sampler's
            internal timestep.
    """

    def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
        self.transform = transform
        t0 = self.transform(torch.as_tensor(sigma_min))
        t1 = self.transform(torch.as_tensor(sigma_max))
        self.tree = BatchedBrownianTree(x, t0, t1, seed)

    def __call__(self, sigma, sigma_next):
        t0 = self.transform(torch.as_tensor(sigma))
        t1 = self.transform(torch.as_tensor(sigma_next))
        # Normalize the Brownian increment to unit variance.
        return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
115
+
116
+
117
@torch.no_grad()
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).

    model is a denoiser callable model(x, sigma, **extra_args) -> denoised x;
    sigmas is a descending schedule whose last entry is expected to be 0.
    s_churn/s_tmin/s_tmax/s_noise control optional stochastic "churn".
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])  # broadcasts the scalar sigma over the batch
    for i in trange(len(sigmas) - 1, disable=disable):
        # gamma > 0 temporarily raises the noise level within [s_tmin, s_tmax].
        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            # Inject enough noise to move from sigmas[i] up to sigma_hat.
            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        dt = sigmas[i + 1] - sigma_hat
        # Euler method
        x = x + d * dt
    return x
136
+
137
+
138
@torch.no_grad()
def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with Euler method steps.

    Each step splits into a deterministic Euler step down to sigma_down plus
    re-noising back up to sigmas[i + 1]; eta controls the split.
    """
    extra_args = {} if extra_args is None else extra_args
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])  # broadcasts the scalar sigma over the batch
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        d = to_d(x, sigmas[i], denoised)
        # Euler method
        dt = sigma_down - sigmas[i]
        x = x + d * dt
        if sigmas[i + 1] > 0:
            # Ancestral part: add back sigma_up worth of fresh noise.
            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
    return x
156
+
157
+
158
@torch.no_grad()
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).

    Second-order: averages the derivative at the current and next sigma,
    falling back to plain Euler on the final step where sigma is 0.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])  # broadcasts the scalar sigma over the batch
    for i in trange(len(sigmas) - 1, disable=disable):
        # gamma > 0 temporarily raises the noise level within [s_tmin, s_tmax].
        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        dt = sigmas[i + 1] - sigma_hat
        if sigmas[i + 1] == 0:
            # Euler method (the slope at sigma == 0 is undefined)
            x = x + d * dt
        else:
            # Heun's method: correct with the average of the two slopes
            x_2 = x + d * dt
            denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
            d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
            d_prime = (d + d_2) / 2
            x = x + d_prime * dt
    return x
185
+
186
+
187
@torch.no_grad()
def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).

    Second-order: evaluates the model again at a log-space midpoint between
    sigma_hat and sigmas[i + 1].
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])  # broadcasts the scalar sigma over the batch
    for i in trange(len(sigmas) - 1, disable=disable):
        # gamma > 0 temporarily raises the noise level within [s_tmin, s_tmax].
        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Euler method (log-space midpoint is undefined when next sigma is 0)
            dt = sigmas[i + 1] - sigma_hat
            x = x + d * dt
        else:
            # DPM-Solver-2: midpoint evaluation in log-sigma space
            sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
            dt_1 = sigma_mid - sigma_hat
            dt_2 = sigmas[i + 1] - sigma_hat
            x_2 = x + d * dt_1
            denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
            d_2 = to_d(x_2, sigma_mid, denoised_2)
            x = x + d_2 * dt_2
    return x
216
+
217
+
218
@torch.no_grad()
def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with DPM-Solver second-order steps.

    Like sample_dpm_2, but each step targets sigma_down from the ancestral
    split and re-injects sigma_up worth of fresh noise afterwards.
    """
    extra_args = {} if extra_args is None else extra_args
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])  # broadcasts the scalar sigma over the batch
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        d = to_d(x, sigmas[i], denoised)
        if sigma_down == 0:
            # Euler method (log-space midpoint is undefined when sigma_down is 0)
            dt = sigma_down - sigmas[i]
            x = x + d * dt
        else:
            # DPM-Solver-2: midpoint evaluation in log-sigma space
            sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp()
            dt_1 = sigma_mid - sigmas[i]
            dt_2 = sigma_down - sigmas[i]
            x_2 = x + d * dt_1
            denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
            d_2 = to_d(x_2, sigma_mid, denoised_2)
            x = x + d_2 * dt_2
        # Ancestral part: add back sigma_up worth of fresh noise.
        x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
    return x
245
+
246
+
247
def linear_multistep_coeff(order, t, i, j):
    """Compute the j-th linear-multistep weight for a step of the given order
    at index i, by integrating the corresponding Lagrange basis polynomial
    over the interval [t[i], t[i + 1]]."""
    if order - 1 > i:
        raise ValueError(f'Order {order} too high for step {i}')

    def basis(tau):
        # Lagrange basis polynomial: 1 at t[i - j], 0 at the other history
        # points t[i - k] for k != j.
        value = 1.
        for k in range(order):
            if k == j:
                continue
            value *= (tau - t[i - k]) / (t[i - j] - t[i - k])
        return value

    return integrate.quad(basis, t[i], t[i + 1], epsrel=1e-4)[0]
258
+
259
+
260
@torch.no_grad()
def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
    """Linear multistep sampler: advances the probability-flow ODE using up to
    `order` stored derivative evaluations per step, with coefficients computed
    by `linear_multistep_coeff`."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    # Coefficients are integrated on the CPU with SciPy, so keep a NumPy copy.
    sigmas_cpu = sigmas.detach().cpu().numpy()
    ds = []  # history of derivative evaluations, oldest first
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        d = to_d(x, sigmas[i], denoised)
        ds.append(d)
        if len(ds) > order:
            ds.pop(0)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        # Ramp the order up while fewer than `order` history entries exist.
        cur_order = min(i + 1, order)
        coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
        x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
    return x
278
+
279
+
280
@torch.no_grad()
def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4):
    """Estimate log-likelihood by integrating the probability-flow ODE from
    sigma_min to sigma_max, tracking the log-density change with a Hutchinson
    trace estimator (single Rademacher probe vector `v`).

    Returns:
        A tuple ``(log_likelihood_per_sample, {'fevals': n})``.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    v = torch.randint_like(x, 2) * 2 - 1  # Rademacher (+/-1) probe vector
    fevals = 0
    def ode_fn(sigma, x):
        nonlocal fevals
        # Gradients are needed locally for the trace estimate even though the
        # sampler as a whole runs under no_grad.
        with torch.enable_grad():
            x = x[0].detach().requires_grad_()
            denoised = model(x, sigma * s_in, **extra_args)
            d = to_d(x, sigma, denoised)
            fevals += 1
            # v^T (dd/dx) v estimates the divergence of the drift.
            grad = torch.autograd.grad((d * v).sum(), x)[0]
            d_ll = (v * grad).flatten(1).sum(1)
        return d.detach(), d_ll
    x_min = x, x.new_zeros([x.shape[0]])
    t = x.new_tensor([sigma_min, sigma_max])
    sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5')
    latent, delta_ll = sol[0][-1], sol[1][-1]
    # Gaussian prior log-prob at sigma_max plus the accumulated change.
    ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1)
    return ll_prior + delta_ll, {'fevals': fevals}
302
+
303
+
304
class PIDStepSizeController:
    """A PID controller for ODE adaptive step size control."""

    def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8):
        self.h = h
        # PID gains normalized by the solver order.
        self.b1 = (pcoeff + icoeff + dcoeff) / order
        self.b2 = -(pcoeff + 2 * dcoeff) / order
        self.b3 = dcoeff / order
        self.accept_safety = accept_safety
        self.eps = eps
        self.errs = []

    def limiter(self, x):
        # Soft clamp so a single proposal never changes h too violently.
        return 1 + math.atan(x - 1)

    def propose_step(self, error):
        inv_error = 1 / (float(error) + self.eps)
        if not self.errs:
            # Seed the error history on the first call.
            self.errs = [inv_error] * 3
        self.errs[0] = inv_error
        raw = self.errs[0] ** self.b1 * self.errs[1] ** self.b2 * self.errs[2] ** self.b3
        factor = self.limiter(raw)
        accept = factor >= self.accept_safety
        if accept:
            # Shift the history and rescale the step size only on acceptance.
            self.errs[2], self.errs[1] = self.errs[1], self.errs[0]
            self.h *= factor
        return accept
331
+
332
+
333
class DPMSolver(nn.Module):
    """DPM-Solver. See https://arxiv.org/abs/2206.00927."""

    def __init__(self, model, extra_args=None, eps_callback=None, info_callback=None):
        super().__init__()
        self.model = model
        self.extra_args = {} if extra_args is None else extra_args
        # eps_callback fires once per model evaluation (used for progress bars);
        # info_callback receives a per-step progress dict.
        self.eps_callback = eps_callback
        self.info_callback = info_callback

    def t(self, sigma):
        # Solver time variable: t = -log(sigma).
        return -sigma.log()

    def sigma(self, t):
        # Inverse of t(): sigma = exp(-t).
        return t.neg().exp()

    def eps(self, eps_cache, key, x, t, *args, **kwargs):
        """Predict the noise at (x, t), memoizing the result under `key` in
        `eps_cache` so sub-steps can reuse evaluations."""
        if key in eps_cache:
            return eps_cache[key], eps_cache
        sigma = self.sigma(t) * x.new_ones([x.shape[0]])
        eps = (x - self.model(x, sigma, *args, **self.extra_args, **kwargs)) / self.sigma(t)
        if self.eps_callback is not None:
            self.eps_callback()
        return eps, {key: eps, **eps_cache}

    def dpm_solver_1_step(self, x, t, t_next, eps_cache=None):
        """First-order (exponential Euler) update from t to t_next."""
        eps_cache = {} if eps_cache is None else eps_cache
        h = t_next - t
        eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
        x_1 = x - self.sigma(t_next) * h.expm1() * eps
        return x_1, eps_cache

    def dpm_solver_2_step(self, x, t, t_next, r1=1 / 2, eps_cache=None):
        """Second-order update; r1 places the intermediate evaluation point."""
        eps_cache = {} if eps_cache is None else eps_cache
        h = t_next - t
        eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
        s1 = t + r1 * h
        u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
        eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
        x_2 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / (2 * r1) * h.expm1() * (eps_r1 - eps)
        return x_2, eps_cache

    def dpm_solver_3_step(self, x, t, t_next, r1=1 / 3, r2=2 / 3, eps_cache=None):
        """Third-order update with intermediate points at r1 and r2."""
        eps_cache = {} if eps_cache is None else eps_cache
        h = t_next - t
        eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
        s1 = t + r1 * h
        s2 = t + r2 * h
        u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
        eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
        u2 = x - self.sigma(s2) * (r2 * h).expm1() * eps - self.sigma(s2) * (r2 / r1) * ((r2 * h).expm1() / (r2 * h) - 1) * (eps_r1 - eps)
        eps_r2, eps_cache = self.eps(eps_cache, 'eps_r2', u2, s2)
        x_3 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / r2 * (h.expm1() / h - 1) * (eps_r2 - eps)
        return x_3, eps_cache

    def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None):
        """Fixed-step sampling that spends exactly `nfe` model evaluations,
        taking mostly third-order steps with a lower-order tail."""
        noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
        if not t_end > t_start and eta:
            raise ValueError('eta must be 0 for reverse sampling')

        m = math.floor(nfe / 3) + 1
        ts = torch.linspace(t_start, t_end, m + 1, device=x.device)

        # Partition nfe into step orders: mostly 3s, remainder at the end.
        if nfe % 3 == 0:
            orders = [3] * (m - 2) + [2, 1]
        else:
            orders = [3] * (m - 1) + [nfe % 3]

        for i in range(len(orders)):
            eps_cache = {}
            t, t_next = ts[i], ts[i + 1]
            if eta:
                # Ancestral variant: solve to an earlier time, then add noise.
                sd, su = get_ancestral_step(self.sigma(t), self.sigma(t_next), eta)
                t_next_ = torch.minimum(t_end, self.t(sd))
                su = (self.sigma(t_next) ** 2 - self.sigma(t_next_) ** 2) ** 0.5
            else:
                t_next_, su = t_next, 0.

            eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
            denoised = x - self.sigma(t) * eps
            if self.info_callback is not None:
                self.info_callback({'x': x, 'i': i, 't': ts[i], 't_up': t, 'denoised': denoised})

            if orders[i] == 1:
                x, eps_cache = self.dpm_solver_1_step(x, t, t_next_, eps_cache=eps_cache)
            elif orders[i] == 2:
                x, eps_cache = self.dpm_solver_2_step(x, t, t_next_, eps_cache=eps_cache)
            else:
                x, eps_cache = self.dpm_solver_3_step(x, t, t_next_, eps_cache=eps_cache)

            x = x + su * s_noise * noise_sampler(self.sigma(t), self.sigma(t_next))

        return x

    def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None):
        """Adaptive step-size sampling: a lower-order and a higher-order step
        are compared to estimate local error, and a PID controller adjusts h.

        Returns a tuple of (sample, info dict with step/NFE counters).
        """
        noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
        if order not in {2, 3}:
            raise ValueError('order should be 2 or 3')
        forward = t_end > t_start
        if not forward and eta:
            raise ValueError('eta must be 0 for reverse sampling')
        h_init = abs(h_init) * (1 if forward else -1)
        atol = torch.tensor(atol)
        rtol = torch.tensor(rtol)
        s = t_start
        x_prev = x
        accept = True
        pid = PIDStepSizeController(h_init, pcoeff, icoeff, dcoeff, 1.5 if eta else order, accept_safety)
        info = {'steps': 0, 'nfe': 0, 'n_accept': 0, 'n_reject': 0}

        while s < t_end - 1e-5 if forward else s > t_end + 1e-5:
            eps_cache = {}
            t = torch.minimum(t_end, s + pid.h) if forward else torch.maximum(t_end, s + pid.h)
            if eta:
                sd, su = get_ancestral_step(self.sigma(s), self.sigma(t), eta)
                t_ = torch.minimum(t_end, self.t(sd))
                su = (self.sigma(t) ** 2 - self.sigma(t_) ** 2) ** 0.5
            else:
                t_, su = t, 0.

            eps, eps_cache = self.eps(eps_cache, 'eps', x, s)
            denoised = x - self.sigma(s) * eps

            # Embedded pair: one step each at order-1/order and order/order+1.
            if order == 2:
                x_low, eps_cache = self.dpm_solver_1_step(x, s, t_, eps_cache=eps_cache)
                x_high, eps_cache = self.dpm_solver_2_step(x, s, t_, eps_cache=eps_cache)
            else:
                x_low, eps_cache = self.dpm_solver_2_step(x, s, t_, r1=1 / 3, eps_cache=eps_cache)
                x_high, eps_cache = self.dpm_solver_3_step(x, s, t_, eps_cache=eps_cache)
            delta = torch.maximum(atol, rtol * torch.maximum(x_low.abs(), x_prev.abs()))
            error = torch.linalg.norm((x_low - x_high) / delta) / x.numel() ** 0.5
            accept = pid.propose_step(error)
            if accept:
                x_prev = x_low
                x = x_high + su * s_noise * noise_sampler(self.sigma(s), self.sigma(t))
                s = t
                info['n_accept'] += 1
            else:
                info['n_reject'] += 1
            info['nfe'] += order
            info['steps'] += 1

            if self.info_callback is not None:
                self.info_callback({'x': x, 'i': info['steps'] - 1, 't': s, 't_up': s, 'denoised': denoised, 'error': error, 'h': pid.h, **info})

        return x, info
479
+
480
+
481
@torch.no_grad()
def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1., noise_sampler=None):
    """DPM-Solver-Fast (fixed step size). See https://arxiv.org/abs/2206.00927."""
    if sigma_min <= 0 or sigma_max <= 0:
        raise ValueError('sigma_min and sigma_max must not be 0')
    with tqdm(total=n, disable=disable) as pbar:
        dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
        if callback is not None:
            # Translate solver time back into sigma for the callback dict.
            dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})
        # Integrate from t(sigma_max) to t(sigma_min) with exactly n evaluations.
        return dpm_solver.dpm_solver_fast(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), n, eta, s_noise, noise_sampler)
491
+
492
+
493
@torch.no_grad()
def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None, return_info=False):
    """DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927."""
    if sigma_min <= 0 or sigma_max <= 0:
        raise ValueError('sigma_min and sigma_max must not be 0')
    with tqdm(disable=disable) as pbar:
        dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
        if callback is not None:
            # Translate solver time back into sigma for the callback dict.
            dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})
        x, info = dpm_solver.dpm_solver_adaptive(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise, noise_sampler)
    if return_info:
        return x, info
    return x
506
+
507
+
508
@torch.no_grad()
def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with DPM-Solver++(2S) second-order steps."""
    extra_args = {} if extra_args is None else extra_args
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])
    # The solver works in t = -log(sigma) space.
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigma_down == 0:
            # Euler method
            d = to_d(x, sigmas[i], denoised)
            dt = sigma_down - sigmas[i]
            x = x + d * dt
        else:
            # DPM-Solver++(2S): single intermediate evaluation at the midpoint.
            t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
            r = 1 / 2
            h = t_next - t
            s = t + r * h
            x_2 = (sigma_fn(s) / sigma_fn(t)) * x - (-h * r).expm1() * denoised
            denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_2
        # Noise addition
        if sigmas[i + 1] > 0:
            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
    return x
540
+
541
+
542
@torch.no_grad()
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
    """DPM-Solver++ (stochastic)."""
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    # Brownian-tree noise makes results reproducible across step counts.
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Euler method
            d = to_d(x, sigmas[i], denoised)
            dt = sigmas[i + 1] - sigmas[i]
            x = x + d * dt
        else:
            # DPM-Solver++
            t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
            h = t_next - t
            s = t + h * r
            fac = 1 / (2 * r)

            # Step 1: stochastic half-step to the intermediate point s.
            sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
            s_ = t_fn(sd)
            x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
            x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
            denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)

            # Step 2: full step using a blend of both denoised estimates.
            sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
            t_next_ = t_fn(sd)
            denoised_d = (1 - fac) * denoised + fac * denoised_2
            x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
            x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
    return x
582
+
583
+
584
@torch.no_grad()
def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """DPM-Solver++(2M)."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()
    old_denoised = None  # previous step's model output, for the multistep correction

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
        h = t_next - t
        if old_denoised is None or sigmas[i + 1] == 0:
            # First step (or final denoise): first-order update.
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
        else:
            h_last = t - t_fn(sigmas[i - 1])
            r = h_last / h
            # Second-order extrapolation from the previous denoised estimate.
            denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
        old_denoised = denoised
    return x
608
+
609
+
610
@torch.no_grad()
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
    """DPM-Solver++(2M) SDE."""

    if solver_type not in {'heun', 'midpoint'}:
        raise ValueError('solver_type must be \'heun\' or \'midpoint\'')

    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    # Brownian-tree noise makes results reproducible across step counts.
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    old_denoised = None  # previous model output, for the 2nd-order correction
    h_last = None        # previous step size in t-space

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Denoising step
            x = denoised
        else:
            # DPM-Solver++(2M) SDE
            t, s = -sigmas[i].log(), -sigmas[i + 1].log()
            h = s - t
            eta_h = eta * h

            x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised

            if old_denoised is not None:
                r = h_last / h
                if solver_type == 'heun':
                    x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
                elif solver_type == 'midpoint':
                    x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)

            if eta:
                # Inject the stochastic (SDE) noise component.
                x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise

            old_denoised = denoised
            h_last = h
    return x
653
+
654
+
655
@torch.no_grad()
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """DPM-Solver++(3M) SDE."""

    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    # Brownian-tree noise makes results reproducible across step counts.
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    # History of the last two model outputs and step sizes for the
    # second/third-order multistep corrections.
    denoised_1, denoised_2 = None, None
    h_1, h_2 = None, None

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Denoising step
            x = denoised
        else:
            t, s = -sigmas[i].log(), -sigmas[i + 1].log()
            h = s - t
            h_eta = h * (eta + 1)

            x = torch.exp(-h_eta) * x + (-h_eta).expm1().neg() * denoised

            if h_2 is not None:
                # Third-order correction from two history entries.
                r0 = h_1 / h
                r1 = h_2 / h
                d1_0 = (denoised - denoised_1) / r0
                d1_1 = (denoised_1 - denoised_2) / r1
                d1 = d1_0 + (d1_0 - d1_1) * r0 / (r0 + r1)
                d2 = (d1_0 - d1_1) / (r0 + r1)
                phi_2 = h_eta.neg().expm1() / h_eta + 1
                phi_3 = phi_2 / h_eta - 0.5
                x = x + phi_2 * d1 - phi_3 * d2
            elif h_1 is not None:
                # Second-order correction from one history entry.
                r = h_1 / h
                d = (denoised - denoised_1) / r
                phi_2 = h_eta.neg().expm1() / h_eta + 1
                x = x + phi_2 * d

            if eta:
                # Inject the stochastic (SDE) noise component.
                x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise

            denoised_1, denoised_2 = denoised, denoised_1
            h_1, h_2 = h, h_1
    return x
k_diffusion/utils.py ADDED
@@ -0,0 +1,457 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from contextlib import contextmanager
import hashlib
import math
from pathlib import Path
import shutil
import threading
import urllib
import urllib.request
import warnings

from PIL import Image
import safetensors
import torch
from torch import nn, optim
from torch.utils import data
from torchvision.transforms import functional as TF
16
+
17
+
18
def from_pil_image(x):
    """Converts from a PIL image to a tensor scaled to [-1, 1]."""
    tensor = TF.to_tensor(x)
    if tensor.ndim == 2:
        # NOTE(review): to_tensor normally yields a CHW tensor, so this branch
        # looks defensive; it appends a trailing singleton axis — confirm intent.
        tensor = tensor[..., None]
    return tensor * 2 - 1
24
+
25
+
26
def to_pil_image(x):
    """Converts from a tensor to a PIL image."""
    img = x
    if img.ndim == 4:
        # A 4-D input must be a singleton batch.
        assert img.shape[0] == 1
        img = img[0]
    if img.shape[0] == 1:
        # Squeeze a single-channel (grayscale) axis.
        img = img[0]
    scaled = (img.clamp(-1, 1) + 1) / 2
    return TF.to_pil_image(scaled)
34
+
35
+
36
def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'):
    """Apply passed in transforms for HuggingFace Datasets."""
    transformed = []
    for image in examples[image_key]:
        transformed.append(transform(image.convert(mode)))
    return {image_key: transformed}
40
+
41
+
42
def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    extra = target_dims - x.ndim
    if extra < 0:
        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
    index = (...,) + (None,) * extra
    return x[index]
48
+
49
+
50
def n_params(module):
    """Returns the total number of parameters in a module.

    NOTE(review): this counts every parameter, not only those with
    ``requires_grad=True``, despite the original "trainable" wording.
    """
    return sum(p.numel() for p in module.parameters())
53
+
54
+
55
def download_file(path, url, digest=None):
    """Download ``url`` to ``path`` if the file does not already exist,
    optionally checking its SHA-256 hash.

    Args:
        path: Destination file path.
        url: Source URL (fetched only when the file is missing).
        digest: Optional expected SHA-256 hex digest; validated whenever
            given, even if the file already existed.

    Returns:
        The destination path as a ``pathlib.Path``.

    Raises:
        OSError: If the file's SHA-256 digest does not match ``digest``.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    if not path.exists():
        with urllib.request.urlopen(url) as response, open(path, 'wb') as f:
            shutil.copyfileobj(response, f)
    if digest is not None:
        # Hash in chunks with the file properly closed afterwards; the
        # original read the whole file through a never-closed handle.
        hasher = hashlib.sha256()
        with open(path, 'rb') as f:
            for chunk in iter(lambda: f.read(1 << 20), b''):
                hasher.update(chunk)
        if digest != hasher.hexdigest():
            raise OSError(f'hash of {path} (url: {url}) failed to validate')
    return path
67
+
68
+
69
@contextmanager
def train_mode(model, mode=True):
    """A context manager that places a model into training mode and restores
    the previous mode on exit."""
    saved = [m.training for m in model.modules()]
    try:
        yield model.train(mode)
    finally:
        # Restore each submodule's original training flag.
        for m, was_training in zip(model.modules(), saved):
            m.training = was_training
79
+
80
+
81
def eval_mode(model):
    """A context manager that places a model into evaluation mode and restores
    the previous mode on exit."""
    # Thin wrapper: delegates to train_mode with mode=False.
    return train_mode(model, False)
85
+
86
+
87
@torch.no_grad()
def ema_update(model, averaged_model, decay):
    """Incorporates updated model parameters into an exponential moving averaged
    version of a model. It should be called after each optimizer step."""
    params = dict(model.named_parameters())
    avg_params = dict(averaged_model.named_parameters())
    assert params.keys() == avg_params.keys()

    # EMA blend: avg <- decay * avg + (1 - decay) * param.
    for name in params:
        avg_params[name].lerp_(params[name], 1 - decay)

    buffers = dict(model.named_buffers())
    avg_buffers = dict(averaged_model.named_buffers())
    assert buffers.keys() == avg_buffers.keys()

    # Buffers are copied verbatim rather than averaged.
    for name in buffers:
        avg_buffers[name].copy_(buffers[name])
104
+
105
+
106
class EMAWarmup:
    """Implements an EMA warmup using an inverse decay schedule.
    If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are
    good values for models you plan to train for a million or more steps (reaches decay
    factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models
    you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
    215.4k steps).

    Args:
        inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
        power (float): Exponential factor of EMA warmup. Default: 1.
        min_value (float): The minimum EMA decay rate. Default: 0.
        max_value (float): The maximum EMA decay rate. Default: 1.
        start_at (int): The epoch to start averaging at. Default: 0.
        last_epoch (int): The index of last epoch. Default: 0.
    """

    def __init__(self, inv_gamma=1., power=1., min_value=0., max_value=1., start_at=0,
                 last_epoch=0):
        self.inv_gamma = inv_gamma
        self.power = power
        self.min_value = min_value
        self.max_value = max_value
        self.start_at = start_at
        self.last_epoch = last_epoch

    def state_dict(self):
        """Returns the state of the class as a :class:`dict`."""
        return dict(self.__dict__)

    def load_state_dict(self, state_dict):
        """Loads the class's state.

        Args:
            state_dict (dict): scaler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_value(self):
        """Gets the current EMA decay rate."""
        epoch = max(0, self.last_epoch - self.start_at)
        decay = 1 - (1 + epoch / self.inv_gamma) ** -self.power
        if epoch < 0:
            # Unreachable given the clamp above; kept for parity with the
            # original implementation.
            return 0.
        return min(self.max_value, max(self.min_value, decay))

    def step(self):
        """Updates the step count."""
        self.last_epoch += 1
152
+
153
+
154
class InverseLR(optim.lr_scheduler._LRScheduler):
    """Implements an inverse decay learning rate schedule with an optional exponential
    warmup. When last_epoch=-1, sets initial lr as lr.
    inv_gamma is the number of steps/epochs required for the learning rate to decay to
    (1 / 2)**power of its original value.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1.
        power (float): Exponential factor of learning rate decay. Default: 1.
        warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
            Default: 0.
        min_lr (float): The minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
    """

    def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., min_lr=0.,
                 last_epoch=-1, verbose=False):
        if not 0. <= warmup < 1:
            raise ValueError('Invalid value for warmup')
        self.inv_gamma = inv_gamma
        self.power = power
        self.warmup = warmup
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.")
        return self._get_closed_form_lr()

    def _get_closed_form_lr(self):
        # Warmup approaches 1 exponentially; decay follows an inverse power law.
        warmup = 1 - self.warmup ** (self.last_epoch + 1)
        decay = (1 + self.last_epoch / self.inv_gamma) ** -self.power
        lrs = []
        for base_lr in self.base_lrs:
            lrs.append(warmup * max(self.min_lr, base_lr * decay))
        return lrs
193
+
194
+
195
class ExponentialLR(optim.lr_scheduler._LRScheduler):
    """Implements an exponential learning rate schedule with an optional exponential
    warmup. When last_epoch=-1, sets initial lr as lr. Decays the learning rate
    continuously by decay (default 0.5) every num_steps steps.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        num_steps (float): The number of steps to decay the learning rate by decay in.
        decay (float): The factor by which to decay the learning rate every num_steps
            steps. Default: 0.5.
        warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
            Default: 0.
        min_lr (float): The minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
    """

    def __init__(self, optimizer, num_steps, decay=0.5, warmup=0., min_lr=0.,
                 last_epoch=-1, verbose=False):
        if not 0. <= warmup < 1:
            raise ValueError('Invalid value for warmup')
        self.num_steps = num_steps
        self.decay = decay
        self.warmup = warmup
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.")
        return self._get_closed_form_lr()

    def _get_closed_form_lr(self):
        warmup = 1 - self.warmup ** (self.last_epoch + 1)
        # Continuous decay: `decay` per `num_steps` epochs.
        per_step = self.decay ** (1 / self.num_steps)
        factor = per_step ** self.last_epoch
        return [warmup * max(self.min_lr, base_lr * factor)
                for base_lr in self.base_lrs]
234
+
235
+
236
class ConstantLRWithWarmup(optim.lr_scheduler._LRScheduler):
    """Implements a constant learning rate schedule with an optional exponential
    warmup. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
            Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
    """

    def __init__(self, optimizer, warmup=0., last_epoch=-1, verbose=False):
        if not 0. <= warmup < 1:
            raise ValueError('Invalid value for warmup')
        self.warmup = warmup
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.")
        return self._get_closed_form_lr()

    def _get_closed_form_lr(self):
        # The warmup factor approaches 1 exponentially; the base lr is constant.
        warmup = 1 - self.warmup ** (self.last_epoch + 1)
        return [warmup * base_lr for base_lr in self.base_lrs]
264
+
265
+
266
def stratified_uniform(shape, group=0, groups=1, dtype=None, device=None):
    """Draws stratified samples from a uniform distribution."""
    if groups <= 0:
        raise ValueError(f"groups must be positive, got {groups}")
    if group < 0 or group >= groups:
        raise ValueError(f"group must be in [0, {groups})")
    # Total number of strata across all groups; this group's strata start at
    # `group` and are spaced `groups` apart.
    n = shape[-1] * groups
    strata = torch.arange(group, n, groups, dtype=dtype, device=device)
    jitter = torch.rand(shape, dtype=dtype, device=device)
    return (strata + jitter) / n
276
+
277
+
278
# Thread-local storage for the stratified-sampling settings installed by
# enable_stratified().
stratified_settings = threading.local()


@contextmanager
def enable_stratified(group=0, groups=1, disable=False):
    """A context manager that enables stratified sampling."""
    try:
        stratified_settings.disable = disable
        stratified_settings.group = group
        stratified_settings.groups = groups
        yield
    finally:
        # Always clear the thread-local settings on exit.
        del stratified_settings.disable
        del stratified_settings.group
        del stratified_settings.groups
293
+
294
+
295
@contextmanager
def enable_stratified_accelerate(accelerator, disable=False):
    """A context manager that enables stratified sampling, distributing the strata across
    all processes and gradient accumulation steps using settings from Hugging Face Accelerate."""
    rank = accelerator.process_index
    world_size = accelerator.num_processes
    acc_steps = accelerator.gradient_state.num_steps
    acc_step = accelerator.step % acc_steps
    # Each (process, accumulation step) pair gets its own stratum group.
    group = rank * acc_steps + acc_step
    groups = world_size * acc_steps
    with enable_stratified(group, groups, disable=disable):
        yield
310
+
311
+
312
def stratified_with_settings(shape, dtype=None, device=None):
    """Draws uniform samples, stratified when an ``enable_stratified`` context
    is active, and plain ``torch.rand`` otherwise."""
    # No active context (attribute absent) or explicitly disabled -> plain rand.
    if getattr(stratified_settings, "disable", True):
        return torch.rand(shape, dtype=dtype, device=device)
    return stratified_uniform(
        shape, stratified_settings.group, stratified_settings.groups, dtype=dtype, device=device
    )
320
+
321
+
322
def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32):
    """Draws samples from a lognormal distribution via inverse-CDF sampling."""
    # Squeeze u away from exactly 0/1 so the Normal icdf stays finite.
    u = stratified_with_settings(shape, device=device, dtype=dtype)
    u = u * (1 - 2e-7) + 1e-7
    return torch.distributions.Normal(loc, scale).icdf(u).exp()
326
+
327
+
328
def rand_log_logistic(shape, loc=0., scale=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
    """Draws samples from an optionally truncated log-logistic distribution."""
    lo = torch.as_tensor(min_value, device=device, dtype=torch.float64)
    hi = torch.as_tensor(max_value, device=device, dtype=torch.float64)
    # CDF of the log-logistic at the truncation bounds (done in float64 for accuracy).
    cdf_lo = lo.log().sub(loc).div(scale).sigmoid()
    cdf_hi = hi.log().sub(loc).div(scale).sigmoid()
    u = stratified_with_settings(shape, device=device, dtype=torch.float64)
    u = u * (cdf_hi - cdf_lo) + cdf_lo
    return u.logit().mul(scale).add(loc).exp().to(dtype)
336
+
337
+
338
def rand_log_uniform(shape, min_value, max_value, device='cpu', dtype=torch.float32):
    """Draws samples from a log-uniform distribution over [min_value, max_value]."""
    log_lo = math.log(min_value)
    log_hi = math.log(max_value)
    u = stratified_with_settings(shape, device=device, dtype=dtype)
    return (u * (log_hi - log_lo) + log_lo).exp()
343
+
344
+
345
def rand_v_diffusion(shape, sigma_data=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
    """Draws samples from a truncated v-diffusion training timestep distribution."""
    # Map the sigma bounds through the arctan CDF of sigma / sigma_data.
    cdf_lo = math.atan(min_value / sigma_data) * 2 / math.pi
    cdf_hi = math.atan(max_value / sigma_data) * 2 / math.pi
    u = stratified_with_settings(shape, device=device, dtype=dtype)
    u = u * (cdf_hi - cdf_lo) + cdf_lo
    return torch.tan(u * math.pi / 2) * sigma_data
351
+
352
+
353
def rand_cosine_interpolated(shape, image_d, noise_d_low, noise_d_high, sigma_data=1., min_value=1e-3, max_value=1e3, device='cpu', dtype=torch.float32):
    """Draws samples from an interpolated cosine timestep distribution (from simple diffusion)."""

    def logsnr_schedule_cosine(t, logsnr_min, logsnr_max):
        # Cosine log-SNR schedule truncated to [logsnr_min, logsnr_max].
        t_min = math.atan(math.exp(-0.5 * logsnr_max))
        t_max = math.atan(math.exp(-0.5 * logsnr_min))
        return -2 * torch.log(torch.tan(t_min + t * (t_max - t_min)))

    def logsnr_schedule_cosine_shifted(t, image_d, noise_d, logsnr_min, logsnr_max):
        # Shift the schedule to a reference noise resolution noise_d relative to image_d.
        shift = 2 * math.log(noise_d / image_d)
        return logsnr_schedule_cosine(t, logsnr_min - shift, logsnr_max - shift) + shift

    def logsnr_schedule_cosine_interpolated(t, image_d, noise_d_low, noise_d_high, logsnr_min, logsnr_max):
        # Linearly interpolate (in t) between the low- and high-resolution shifted schedules.
        logsnr_low = logsnr_schedule_cosine_shifted(t, image_d, noise_d_low, logsnr_min, logsnr_max)
        logsnr_high = logsnr_schedule_cosine_shifted(t, image_d, noise_d_high, logsnr_min, logsnr_max)
        return torch.lerp(logsnr_low, logsnr_high, t)

    # Convert the sigma bounds to log-SNR bounds: logsnr = -2 log(sigma / sigma_data).
    logsnr_min = -2 * math.log(min_value / sigma_data)
    logsnr_max = -2 * math.log(max_value / sigma_data)
    u = stratified_with_settings(shape, device=device, dtype=dtype)
    logsnr = logsnr_schedule_cosine_interpolated(u, image_d, noise_d_low, noise_d_high, logsnr_min, logsnr_max)
    # Back to sigma space.
    return torch.exp(-logsnr / 2) * sigma_data
375
+
376
+
377
def rand_split_log_normal(shape, loc, scale_1, scale_2, device='cpu', dtype=torch.float32):
    """Draws samples from a split lognormal distribution (different spread on
    each side of ``loc``)."""
    magnitude = torch.randn(shape, device=device, dtype=dtype).abs()
    side = torch.rand(shape, device=device, dtype=dtype)
    left = magnitude * -scale_1 + loc
    right = magnitude * scale_2 + loc
    # Choose the left branch with probability scale_1 / (scale_1 + scale_2).
    p_left = scale_1 / (scale_1 + scale_2)
    return torch.where(side < p_left, left, right).exp()
385
+
386
+
387
class FolderOfImages(data.Dataset):
    """Recursively finds all images in a directory. It does not support
    classes/targets."""

    IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'}

    def __init__(self, root, transform=None):
        super().__init__()
        self.root = Path(root)
        self.transform = transform if transform is not None else nn.Identity()
        # Case-insensitive extension match over the whole tree, sorted for determinism.
        candidates = self.root.rglob('*')
        self.paths = sorted(p for p in candidates if p.suffix.lower() in self.IMG_EXTENSIONS)

    def __repr__(self):
        return f'FolderOfImages(root="{self.root}", len: {len(self)})'

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, key):
        path = self.paths[key]
        with open(path, 'rb') as f:
            image = Image.open(f).convert('RGB')
            image = self.transform(image)
        # 1-tuple, matching torch Dataset conventions for unlabeled data.
        return (image,)
411
+
412
+
413
class CSVLogger:
    """Appends rows of comma-separated values to a file.

    If the file already exists, rows are appended without rewriting the
    header; otherwise it is created and ``columns`` is written first.

    Fix: the original never closed its file handle. ``close()`` and context
    manager support are added; existing callers are unaffected.
    """

    def __init__(self, filename, columns):
        self.filename = Path(filename)
        self.columns = columns
        if self.filename.exists():
            self.file = open(self.filename, 'a')
        else:
            self.file = open(self.filename, 'w')
            self.write(*self.columns)

    def write(self, *args):
        # Flush every row so partial logs survive a crash.
        print(*args, sep=',', file=self.file, flush=True)

    def close(self):
        """Release the underlying file handle (idempotent)."""
        if not self.file.closed:
            self.file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
425
+
426
+
427
@contextmanager
def tf32_mode(cudnn=None, matmul=None):
    """A context manager that sets whether TF32 is allowed on cuDNN or matmul.

    ``None`` leaves the corresponding flag untouched; touched flags are
    restored on exit.
    """
    backends = torch.backends
    saved = (backends.cudnn.allow_tf32, backends.cuda.matmul.allow_tf32)
    try:
        if cudnn is not None:
            backends.cudnn.allow_tf32 = cudnn
        if matmul is not None:
            backends.cuda.matmul.allow_tf32 = matmul
        yield
    finally:
        # Restore only the flags we changed.
        if cudnn is not None:
            backends.cudnn.allow_tf32 = saved[0]
        if matmul is not None:
            backends.cuda.matmul.allow_tf32 = saved[1]
443
+
444
+
445
def get_safetensors_metadata(path):
    """Retrieves the metadata from a safetensors file.

    Opens the file with the "pt" (PyTorch) framework tag and returns the
    header metadata dict without loading tensor data.
    """
    return safetensors.safe_open(path, "pt").metadata()
448
+
449
+
450
def ema_update_dict(values, updates, decay):
    """Exponential-moving-average update of ``values`` in place from ``updates``.

    Missing keys are initialized to the incoming value; existing entries are
    blended as ``decay * old + (1 - decay) * new`` (in-place ops, so tensor
    entries keep their identity). Returns ``values``.
    """
    for key, new in updates.items():
        if key not in values:
            values[key] = new
            continue
        values[key] *= decay
        values[key] += (1 - decay) * new
    return values
ldm/modules/diffusionmodules/__init__.py ADDED
File without changes
ldm/modules/diffusionmodules/model.py ADDED
@@ -0,0 +1,488 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from einops import rearrange
8
+ from typing import Optional, Any
9
+
10
+ from refnet.util import checkpoint_wrapper, default
11
+
12
+ try:
13
+ import xformers
14
+ import xformers.ops
15
+
16
+ XFORMERS_IS_AVAILBLE = True
17
+ attn_processor = xformers.ops.memory_efficient_attention
18
+ except:
19
+ XFORMERS_IS_AVAILBLE = False
20
+ attn_processor = F.scaled_dot_product_attention
21
+
22
+
23
def get_timestep_embedding(timesteps, embedding_dim):
    """
    Sinusoidal timestep embeddings as in Denoising Diffusion Probabilistic
    Models (fairseq / tensor2tensor style; differs slightly from Section 3.5
    of "Attention Is All You Need").

    Returns a (len(timesteps), embedding_dim) float tensor; an odd
    embedding_dim is zero-padded on the right.
    """
    assert timesteps.ndim == 1

    half = embedding_dim // 2
    step = math.log(10000) / (half - 1)
    freqs = torch.exp(torch.arange(half, dtype=torch.float32) * -step)
    freqs = freqs.to(device=timesteps.device)
    angles = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([angles.sin(), angles.cos()], dim=1)
    if embedding_dim % 2 == 1:  # zero pad the last column
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
42
+
43
+
44
def nonlinearity(x):
    # swish / SiLU activation: x * sigmoid(x)
    return x*torch.sigmoid(x)
47
+
48
+
49
def Normalize(in_channels, num_groups=32):
    """GroupNorm with the eps/affine settings used throughout this VAE."""
    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
51
+
52
+
53
class Upsample(nn.Module):
    """2x nearest-neighbor spatial upsampling, optionally followed by a
    3x3 convolution."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1
            )

    def forward(self, x):
        out = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            out = self.conv(out)
        return out
69
+
70
+
71
class Downsample(nn.Module):
    """2x spatial downsampling via a strided 3x3 conv (with manual
    asymmetric right/bottom padding) or 2x2 average pooling."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0
            )

    def forward(self, x):
        if not self.with_conv:
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        padded = torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
91
+
92
+
93
class ResnetBlock(nn.Module):
    """Pre-activation residual block: (GroupNorm -> swish -> 3x3 conv) twice,
    with optional timestep-embedding injection after the first conv and a
    learned shortcut when the channel count changes."""

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        # temb_channels <= 0 disables the timestep projection entirely
        # (this VAE uses temb_ch = 0, so temb_proj is usually absent).
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels,
                                             out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                # 3x3 convolutional shortcut
                self.conv_shortcut = torch.nn.Conv2d(in_channels,
                                                     out_channels,
                                                     kernel_size=3,
                                                     stride=1,
                                                     padding=1)
            else:
                # 1x1 "network-in-network" shortcut
                self.nin_shortcut = torch.nn.Conv2d(in_channels,
                                                    out_channels,
                                                    kernel_size=1,
                                                    stride=1,
                                                    padding=0)

    @checkpoint_wrapper
    def forward(self, x, temb=None):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            # Broadcast the projected timestep embedding over spatial dims.
            h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x+h
154
+
155
+
156
class AttnBlock(nn.Module):
    """Single-head self-attention over spatial positions (pure-PyTorch path),
    with a residual connection around the attention output."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        normed = self.norm(x)
        q = self.q(normed)
        k = self.k(normed)
        v = self.v(normed)

        b, c, h, w = q.shape
        # scores[b, i, j] = <q_i, k_j> / sqrt(c), softmax over j (keys)
        q = q.reshape(b, c, h * w).permute(0, 2, 1)   # b, hw, c
        k = k.reshape(b, c, h * w)                    # b, c, hw
        scores = torch.bmm(q, k) * (int(c) ** (-0.5))
        scores = torch.nn.functional.softmax(scores, dim=2)

        # out[b, c, i] = sum_j v[b, c, j] * scores[b, i, j]
        v = v.reshape(b, c, h * w)
        out = torch.bmm(v, scores.permute(0, 2, 1)).reshape(b, c, h, w)

        return x + self.proj_out(out)
208
+
209
class MemoryEfficientAttnBlock(nn.Module):
    """
    Uses xformers efficient implementation,
    see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
    Note: this is a single-head self-attention operation
    """
    #
    def __init__(self, in_channels, head_dim=None):
        super().__init__()
        self.in_channels = in_channels
        # head_dim defaults to in_channels, i.e. heads == 1 (the single-head case).
        self.head_dim = default(head_dim, in_channels)
        self.heads = in_channels // self.head_dim
        # if self.head_dim > 256:
        #     self.attn_processor = F.scaled_dot_product_attention
        # else:
        # Module-level attn_processor: xformers memory_efficient_attention if
        # available, otherwise F.scaled_dot_product_attention.
        self.attn_processor = attn_processor

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)
        self.attention_op: Optional[Any] = None

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        B, C, H, W = q.shape
        q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))

        # NOTE(review): the reshape below uses C (full channels) as the
        # per-head feature size; for self.heads > 1 that looks inconsistent
        # with splitting into heads — confirm before using head_dim != None.
        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(B, -1, self.heads, C)
            .permute(0, 2, 1, 3)
            .reshape(B * self.heads, -1, C)
            .contiguous(),
            (q, k, v),
        )
        out = self.attn_processor(q, k, v)

        out = (
            out.unsqueeze(0)
            .reshape(B, 1, out.shape[1], C)
            .permute(0, 2, 1, 3)
            .reshape(B, out.shape[1], C)
        )
        out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
        out = self.proj_out(out)
        # Residual connection.
        return x+out
279
+
280
+
281
def make_attn(in_channels, **kwargs):
    # Always builds the memory-efficient attention block; extra kwargs
    # (e.g. attn_type passed by Encoder/Decoder) are accepted and ignored.
    return MemoryEfficientAttnBlock(in_channels)
283
+
284
+
285
+
286
class Encoder(nn.Module):
    """VAE encoder: conv stem -> per-resolution resnet/attention down stack ->
    middle (resnet, attn, resnet) -> norm/swish/conv out.

    Output has 2*z_channels channels when double_z (mean + logvar moments),
    else z_channels.
    """

    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
                 checkpoint=True, **ignore_kwargs):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0  # no timestep embedding in the autoencoder
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        in_ch_mult = (1,)+tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                # Attention is only inserted at resolutions listed in attn_resolutions.
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        2*z_channels if double_z else z_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)
        self.checkpoint = checkpoint

    @checkpoint_wrapper
    def forward(self, x):
        # timestep embedding (unused; kept for ResnetBlock's signature)
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions-1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
380
+
381
+
382
class Decoder(nn.Module):
    """VAE decoder: conv in from latent -> middle (resnet, attn, resnet) ->
    per-resolution resnet/attention up stack -> norm/swish/conv out.

    ``give_pre_end`` returns the features before the final norm/conv;
    ``tanh_out`` squashes the output to [-1, 1].
    """

    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
                 attn_type="vanilla", checkpoint=True, **ignorekwargs):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0  # no timestep embedding in the autoencoder
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,)+tuple(ch_mult)
        block_in = ch*ch_mult[self.num_resolutions-1]
        curr_res = resolution // 2**(self.num_resolutions-1)
        self.z_shape = (1,z_channels,curr_res,curr_res)

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels,
                                       block_in,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # upsampling (built highest level first, inserted at index 0 so that
        # self.up is ordered the same way as Encoder.down)
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            # One more resnet block per level than the encoder uses.
            for i_block in range(self.num_res_blocks+1):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up) # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)
        self.checkpoint = checkpoint

    @checkpoint_wrapper
    def forward(self, z):
        #assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape

        # timestep embedding (unused; kept for ResnetBlock's signature)
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks+1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h
ldm/modules/distributions/__init__.py ADDED
File without changes
ldm/modules/distributions/distributions.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
class AbstractDistribution:
    """Interface for latent distributions: subclasses provide sampling and
    the distribution mode."""

    def sample(self):
        raise NotImplementedError()

    def mode(self):
        raise NotImplementedError()
11
+
12
+
13
class DiracDistribution(AbstractDistribution):
    """Deterministic (point-mass) distribution: sampling and the mode both
    return the stored value."""

    def __init__(self, value):
        self.value = value

    def sample(self):
        return self.value

    def mode(self):
        return self.value
22
+
23
+
24
class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian over latents, parameterized by [mean, logvar]
    concatenated along dim 1. ``deterministic=True`` collapses it to a point
    mass at the mean (zero std/var)."""

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        # Clamp logvar for numerical stability of exp().
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if deterministic:
            zeros = torch.zeros_like(self.mean).to(device=self.parameters.device)
            self.var = self.std = zeros

    def sample(self):
        noise = torch.randn(self.mean.shape).to(device=self.parameters.device)
        return self.mean + self.std * noise

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.])
        if other is None:
            # KL(N(mean, var) || N(0, I)), summed over non-batch dims.
            return 0.5 * torch.sum(torch.pow(self.mean, 2)
                                   + self.var - 1.0 - self.logvar,
                                   dim=[1, 2, 3])
        # KL between two diagonal Gaussians.
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var - 1.0 - self.logvar + other.logvar,
            dim=[1, 2, 3])

    def nll(self, sample, dims=[1,2,3]):
        if self.deterministic:
            return torch.Tensor([0.])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        return self.mean
63
+
64
+
65
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    # Find any Tensor argument to borrow dtype/device from.
    ref = next(
        (a for a in (mean1, logvar1, mean2, logvar2) if isinstance(a, torch.Tensor)),
        None,
    )
    assert ref is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for torch.exp().
    logvar1, logvar2 = (
        v if isinstance(v, torch.Tensor) else torch.tensor(v).to(ref)
        for v in (logvar1, logvar2)
    )

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + torch.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )
preprocessor/__init__.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch.hub
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import torchvision.transforms.functional as tf
7
+ import functools
8
+
9
+ model_path = "preprocessor/weights"
10
+ os.environ["HF_HOME"] = model_path
11
+ torch.hub.set_dir(model_path)
12
+
13
+ from torch.hub import download_url_to_file
14
+ from transformers import AutoModelForImageSegmentation
15
+ from .anime2sketch import UnetGenerator
16
+ from .manga_line_extractor import res_skip
17
+ from .sketchKeras import SketchKeras
18
+ from .sk_model import LineartDetector
19
+ from .anime_segment import ISNetDIS
20
+ from refnet.util import load_weights
21
+
22
+
23
class NoneMaskExtractor(nn.Module):
    """Fallback mask extractor: always produces an all-zero (empty) mask of
    the input's spatial size, ignoring the extra mask options."""

    def __init__(self):
        super().__init__()
        self.identity = nn.Identity()

    def proceed(self, x: torch.Tensor, th=None, tw=None, dilate=False, *args, **kwargs):
        batch, _, height, width = x.shape
        return torch.zeros([batch, 1, height, width], device=x.device)

    def forward(self, x):
        return self.proceed(x)
34
+
35
+
36
# Download URLs for line-extractor / segmentation checkpoints, keyed by the
# model names accepted by create_model().
remote_model_dict = {
    "lineart": "https://huggingface.co/lllyasviel/Annotators/resolve/main/netG.pth",
    "lineart_denoise": "https://huggingface.co/lllyasviel/Annotators/resolve/main/erika.pth",
    "lineart_keras": "https://huggingface.co/tellurion/line_extractor/resolve/main/model.pth",
    "lineart_sk": "https://huggingface.co/lllyasviel/Annotators/resolve/main/sk_model.pth",
    "ISNet": "https://huggingface.co/tellurion/line_extractor/resolve/main/isnetis.safetensors",
    "ISNet-sketch": "https://huggingface.co/tellurion/line_extractor/resolve/main/sketch-segment.safetensors"
}

# Hugging Face repo id and input resolution for BiRefNet-style
# background-removal models loaded through transformers.
BiRefNet_dict = {
    "rmbg-v2": ("briaai/RMBG-2.0", 1024),
    "BiRefNet": ("ZhengPeng7/BiRefNet", 1024),
    "BiRefNet_HR": ("ZhengPeng7/BiRefNet_HR", 2048)
}
50
+
51
def rmbg_proceed(self, x: torch.Tensor, th=None, tw=None, dilate=False, *args, **kwargs):
    """Bound onto a BiRefNet/RMBG model by create_model(): segment ``x``
    (expected in [-1, 1]) and return a mask in [0, 1] at the input's spatial
    size (padded to (th, tw) when given)."""
    b, c, h, w = x.shape
    x = (x + 1.0) / 2.  # [-1, 1] -> [0, 1]
    x = tf.resize(x, [self.image_size, self.image_size])
    # ImageNet normalization statistics.
    x = tf.normalize(x, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    # Take the last output head and squash to probabilities.
    x = self(x)[-1].sigmoid()
    x = tf.resize(x, [h, w])

    if th and tw:
        # Center-pad the mask up to the (th, tw) canvas.
        x = tf.pad(x, padding=[(th-h)//2, (tw-w)//2])
    if dilate:
        # Grow the mask with a max-pool dilation, then binarize at 0.5.
        x = F.max_pool2d(x, kernel_size=21, stride=1, padding=10)
        # x = F.max_pool2d(x, kernel_size=11, stride=1, padding=5)
        # x = mask_expansion(x, 60, 40)
        x = torch.where(x > 0.5, torch.ones_like(x), torch.zeros_like(x))
    x = x.clamp(0, 1)
    return x
68
+
69
+
70
+
71
def create_model(model_name="lineart"):
    """Create a model for anime2sketch
    hardcoding the options for simplicity

    Returns an eval-mode module:
      - "none": NoneMaskExtractor (always-empty mask),
      - a BiRefNet_dict key: a transformers segmentation model with a
        ``proceed`` method bound on,
      - a remote_model_dict key: the matching line-extractor network with
        weights downloaded/cached under ``model_path``.
    """
    if model_name == "none":
        return NoneMaskExtractor().eval()

    if model_name in BiRefNet_dict.keys():
        model = AutoModelForImageSegmentation.from_pretrained(
            BiRefNet_dict[model_name][0],
            trust_remote_code = True,
            cache_dir = model_path,
            device_map = None,
            low_cpu_mem_usage = False,
        )
        model.eval()
        model.image_size = BiRefNet_dict[model_name][1]
        # Bind rmbg_proceed as an instance method so callers get model.proceed().
        model.proceed = rmbg_proceed.__get__(model, model.__class__)
        return model

    assert model_name in remote_model_dict.keys()
    remote_path = remote_model_dict[model_name]
    basename = os.path.basename(remote_path)
    ckpt_path = os.path.join(model_path, basename)

    if not os.path.exists(model_path):
        os.makedirs(model_path)

    # Download to a temp name and rename, so an interrupted download does not
    # leave a half-written checkpoint at ckpt_path.
    if not os.path.exists(ckpt_path):
        cache_path = "preprocessor/weights/weights.tmp"
        download_url_to_file(remote_path, dst=cache_path)
        os.rename(cache_path, ckpt_path)

    if model_name == "lineart":
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
        model = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False)
    elif model_name == "lineart_denoise":
        model = res_skip()
    elif model_name == "lineart_keras":
        model = SketchKeras()
    elif model_name == "lineart_sk":
        model = LineartDetector()
    elif model_name == "ISNet" or model_name == "ISNet-sketch":
        model = ISNetDIS()
    else:
        return None
    
    # Strip DataParallel's "module." prefix from checkpoint keys.
    ckpt = load_weights(ckpt_path)
    for key in list(ckpt.keys()):
        if 'module.' in key:
            ckpt[key.replace('module.', '')] = ckpt[key]
            del ckpt[key]
    model.load_state_dict(ckpt)
    return model.eval()
preprocessor/anime2sketch.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import functools
4
+ import torchvision.transforms as transforms
5
+
6
+ """
7
+ Anime2Sketch: A sketch extractor for illustration, anime art, manga
8
+ Author: Xiaoyu Zhang
9
+ Github link: https://github.com/Mukosame/Anime2Sketch
10
+ """
11
+
12
def to_tensor(x, inverse=False):
    """Convert an image to a normalized [-1, 1] batch-of-one CUDA tensor,
    negated when ``inverse`` is True.

    NOTE(review): hard-codes .cuda(); this fails on CPU-only machines.
    """
    x = transforms.ToTensor()(x).unsqueeze(0)
    x = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(x).cuda()
    return x if not inverse else -x
16
+
17
+
18
class UnetGenerator(nn.Module):
    """Create a Unet-based generator"""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet generator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            output_nc (int) -- the number of channels in output images
            num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
                                image of size 128x128 will become of size 1x1 # at the bottleneck
            ngf (int)       -- the number of filters in the last conv layer
            norm_layer      -- normalization layer
        We construct the U-Net from the innermost layer to the outermost layer.
        It is a recursive process.
        """
        super(UnetGenerator, self).__init__()
        # construct unet structure
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)  # add the innermost layer
        for _ in range(num_downs - 5):        # add intermediate layers with ngf * 8 filters
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
        # gradually reduce the number of filters from ngf * 8 to ngf
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)  # add the outermost layer

    def forward(self, input):
        """Standard forward"""
        return self.model(input)

    def proceed(self, img):
        # to_tensor inverts nothing here; the network's output is negated so
        # the returned sketch has dark lines on a light background.
        # NOTE(review): relies on to_tensor's hard-coded .cuda().
        sketch = self(to_tensor(img))
        return -sketch
51
+
52
+
53
class UnetSkipConnectionBlock(nn.Module):
    """Defines the Unet submodule with skip connection.
        X -------------------identity----------------------
        |-- downsampling -- |submodule| -- upsampling --|
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet submodule with skip connections.
        Parameters:
            outer_nc (int) -- the number of filters in the outer conv layer
            inner_nc (int) -- the number of filters in the inner conv layer
            input_nc (int) -- the number of channels in input images/features
            submodule (UnetSkipConnectionBlock) -- previously defined submodules
            outermost (bool)    -- if this module is the outermost module
            innermost (bool)    -- if this module is the innermost module
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers.
        """
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        # InstanceNorm has no affine params here, so the convs need a bias.
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            # Outermost: no norm after downconv; Tanh on the way out.
            # Upconv sees inner_nc * 2 channels because of the skip concat.
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            # Innermost: no submodule, upconv input is just inner_nc.
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x).clamp(-1, 1)
        else:   # add skip connections
            return torch.cat([x, self.model(x)], 1)
preprocessor/anime_segment.py ADDED
@@ -0,0 +1,487 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from refnet.util import default
6
+
7
+ """
8
+ Source code: https://github.com/SkyTNT/anime-segmentation?tab=readme-ov-file
9
+ Author: SkyTNT
10
+ """
11
+
12
class REBNCONV(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU block with dilation-matched padding.

    With the default stride of 1 the padding (`1 * dirate`) exactly offsets
    the dilation, so the spatial size of the input is preserved.
    """

    def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
        super(REBNCONV, self).__init__()
        self.conv_s1 = nn.Conv2d(
            in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride
        )
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply convolution, batch norm, and ReLU in sequence."""
        return self.relu_s1(self.bn_s1(self.conv_s1(x)))
27
+
28
+
29
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
30
+ def _upsample_like(src, tar):
31
+ src = F.interpolate(src, size=tar.shape[2:], mode="bilinear", align_corners=False)
32
+
33
+ return src
34
+
35
+
36
+ ### RSU-7 ###
37
class RSU7(nn.Module):
    """Residual U-block of depth 7 (U^2-Net building block).

    Encodes with five 2x max-pools, uses a dilated conv at the bottleneck,
    decodes with bilinear upsampling plus skip concatenations, and returns
    ``decoder_output + input_projection`` so the output spatial size equals
    the input's.

    Args:
        in_ch: number of input channels.
        mid_ch: channel width used inside the U.
        out_ch: number of output channels.
        img_size: unused; kept only for signature compatibility with the
            upstream implementation and existing callers.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
        super(RSU7, self).__init__()

        self.in_ch = in_ch
        self.mid_ch = mid_ch
        self.out_ch = out_ch

        # Input projection; added back to the decoder output in forward().
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bottleneck widens the receptive field with dilation instead of pooling.
        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder convs consume [upsampled, skip] concatenations (2 * mid_ch).
        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Encode, decode with skip connections, add the input projection."""
        hxin = self.rebnconvin(x)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)

        hx6 = self.rebnconv6(hx)

        hx7 = self.rebnconv7(hx6)

        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
        hx6dup = _upsample_like(hx6d, hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        # Residual connection with the input projection.
        return hx1d + hxin
116
+
117
+
118
+ ### RSU-6 ###
119
class RSU6(nn.Module):
    """Residual U-block of depth 6; same pattern as RSU7 with one fewer level.

    Four 2x max-pools in the encoder, a dilated bottleneck conv, bilinear
    upsampling with skip concatenations in the decoder. Output spatial size
    equals the input's, and the input projection is added residually.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU6, self).__init__()

        # Input projection; added back to the decoder output in forward().
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bottleneck uses dilation instead of further pooling.
        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder convs consume [upsampled, skip] concatenations (2 * mid_ch).
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Encode, decode with skip connections, add the input projection."""
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)

        hx6 = self.rebnconv6(hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        # Residual connection with the input projection.
        return hx1d + hxin
183
+
184
+
185
+ ### RSU-5 ###
186
class RSU5(nn.Module):
    """Residual U-block of depth 5; same pattern as RSU7 with three 2x pools.

    Output spatial size equals the input's; the input projection is added
    residually to the decoder output.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU5, self).__init__()

        # Input projection; added back to the decoder output in forward().
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bottleneck uses dilation instead of further pooling.
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder convs consume [upsampled, skip] concatenations (2 * mid_ch).
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Encode, decode with skip connections, add the input projection."""
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)

        hx5 = self.rebnconv5(hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        # Residual connection with the input projection.
        return hx1d + hxin
240
+
241
+
242
+ ### RSU-4 ###
243
class RSU4(nn.Module):
    """Residual U-block of depth 4; same pattern as RSU7 with two 2x pools.

    Output spatial size equals the input's; the input projection is added
    residually to the decoder output.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4, self).__init__()

        # Input projection; added back to the decoder output in forward().
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bottleneck uses dilation instead of further pooling.
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder convs consume [upsampled, skip] concatenations (2 * mid_ch).
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Encode, decode with skip connections, add the input projection."""
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        # Residual connection with the input projection.
        return hx1d + hxin
287
+
288
+
289
+ ### RSU-4F ###
290
class RSU4F(nn.Module):
    """Dilation-only ("flat") residual U-block.

    No pooling or upsampling: receptive field grows via dilation rates
    1/2/4/8, so every intermediate tensor keeps the input's spatial size.
    The input projection is added residually to the decoder output.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F, self).__init__()

        # Input projection; added back to the decoder output in forward().
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder: increasing dilation instead of spatial downsampling.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)

        # Decoder mirrors the encoder dilations on concatenated features.
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Encode with rising dilation, decode with skips, add the projection."""
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
        hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))

        # Residual connection with the input projection.
        return hx1d + hxin
322
+
323
+
324
class myrebnconv(nn.Module):
    """Fully configurable Conv2d -> BatchNorm2d -> ReLU block."""

    def __init__(
        self,
        in_ch=3,
        out_ch=1,
        kernel_size=3,
        stride=1,
        padding=1,
        dilation=1,
        groups=1,
    ):
        super(myrebnconv, self).__init__()
        self.conv = nn.Conv2d(
            in_ch,
            out_ch,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
        )
        self.bn = nn.BatchNorm2d(out_ch)
        self.rl = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply convolution, batch norm, and ReLU in sequence."""
        out = self.conv(x)
        out = self.bn(out)
        return self.rl(out)
351
+
352
+
353
class ISNetDIS(nn.Module):
    """ISNet (Dichotomous Image Segmentation) network built from RSU blocks.

    A U^2-Net-style encoder/decoder. Only the highest-resolution side head
    (`side1`) is used here; the other side heads and `pool_in` are kept as
    attributes so pretrained state dicts load without key mismatches.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super(ISNetDIS, self).__init__()

        # Stem conv halves the resolution before the encoder stages.
        self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
        # NOTE(review): pool_in is not used in forward(); kept for
        # checkpoint/structure compatibility with the upstream code.
        self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage1 = RSU7(64, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        # Decoder stages consume [upsampled, skip] concatenations.
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        # Side heads. Only side1 feeds the returned mask; side2..side6 are
        # unused at inference but kept so pretrained checkpoints still load.
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

    def forward(self, x):
        """Return a sigmoid mask upsampled back to the input's (H, W)."""
        hxin = self.conv_in(x)
        # Fix: the upstream code computed `hx = self.pool_in(hxin)` here and
        # immediately discarded it (stage1 consumes `hxin`); the dead
        # computation is removed.

        # stage 1
        hx1 = self.stage1(hxin)
        hx = self.pool12(hx1)

        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        # stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        # -------------------- decoder --------------------
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # Only the finest side output is used; upsample it to input size.
        d1 = self.side1(hx1d)
        d1 = _upsample_like(d1, x)

        return torch.sigmoid(d1)

    def proceed(self, x: torch.Tensor, th=None, tw=None, s=1024, dilate=False, crop=True, *args, **kwargs):
        """Produce a [0, 1] foreground mask of size (th, tw) for image `x`.

        Args:
            x: image batch (b, c, h, w); values appear to be in [-1, 1]
               given the `1 - (canvas + 1) / 2` remap below — TODO confirm.
            th, tw: target mask size; default to the input's (h, w).
            s: working resolution fed to the network.
            dilate: if True, grow the mask with a 21x21 max-pool.
            crop: if True, letterbox the image into an s x s canvas
                  (preserving aspect ratio) instead of plain resizing.
        """
        b, c, h, w = x.shape

        if crop:
            th, tw = default(th, h), default(tw, w)
            scale = s / max(h, w)
            h, w = int(h * scale), int(w * scale)

            # Letterbox onto a canvas filled with -1 (background value).
            canvas = -torch.ones((b, c, s, s), dtype=x.dtype, device=x.device)
            ph, pw = (s - h) // 2, (s - w) // 2
            x = F.interpolate(x, scale_factor=scale, mode="bicubic")

            canvas[:, :, ph: ph+h, pw: pw+w] = x

            # Remap [-1, 1] -> [1, 0] before inference, then crop the
            # prediction back out of the letterboxed canvas.
            canvas = 1 - (canvas + 1.) / 2.
            mask = self(canvas)[:, :, ph: ph+h, pw: pw+w]

        else:
            x = F.interpolate(x, size=(s, s), mode="bicubic")
            mask = self(x)

        mask = F.interpolate(mask, (th, tw), mode="bicubic").clamp(0, 1)

        if dilate:
            # Morphological dilation via max-pooling (same-size output).
            mask = F.max_pool2d(mask, kernel_size=21, stride=1, padding=10)
        return mask
preprocessor/manga_line_extractor.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torchvision.transforms as transforms
3
+
4
+
5
+ class _bn_relu_conv(nn.Module):
6
+ def __init__(self, in_filters, nb_filters, fw, fh, subsample=1):
7
+ super(_bn_relu_conv, self).__init__()
8
+ self.model = nn.Sequential(
9
+ nn.BatchNorm2d(in_filters, eps=1e-3),
10
+ nn.LeakyReLU(0.2),
11
+ nn.Conv2d(in_filters, nb_filters, (fw, fh), stride=subsample, padding=(fw//2, fh//2), padding_mode='zeros')
12
+ )
13
+
14
+ def forward(self, x):
15
+ return self.model(x)
16
+
17
+
18
+ class _u_bn_relu_conv(nn.Module):
19
+ def __init__(self, in_filters, nb_filters, fw, fh, subsample=1):
20
+ super(_u_bn_relu_conv, self).__init__()
21
+ self.model = nn.Sequential(
22
+ nn.BatchNorm2d(in_filters, eps=1e-3),
23
+ nn.LeakyReLU(0.2),
24
+ nn.Conv2d(in_filters, nb_filters, (fw, fh), stride=subsample, padding=(fw//2, fh//2)),
25
+ nn.Upsample(scale_factor=2, mode='nearest')
26
+ )
27
+
28
+ def forward(self, x):
29
+ return self.model(x)
30
+
31
+
32
+
33
+ class _shortcut(nn.Module):
34
+ def __init__(self, in_filters, nb_filters, subsample=1):
35
+ super(_shortcut, self).__init__()
36
+ self.process = False
37
+ self.model = None
38
+ if in_filters != nb_filters or subsample != 1:
39
+ self.process = True
40
+ self.model = nn.Sequential(
41
+ nn.Conv2d(in_filters, nb_filters, (1, 1), stride=subsample)
42
+ )
43
+
44
+ def forward(self, x, y):
45
+ #print(x.size(), y.size(), self.process)
46
+ if self.process:
47
+ y0 = self.model(x)
48
+ #print("merge+", torch.max(y0+y), torch.min(y0+y),torch.mean(y0+y), torch.std(y0+y), y0.shape)
49
+ return y0 + y
50
+ else:
51
+ #print("merge", torch.max(x+y), torch.min(x+y),torch.mean(x+y), torch.std(x+y), y.shape)
52
+ return x + y
53
+
54
class _u_shortcut(nn.Module):
    """Shortcut for upsampling blocks: when channel counts differ, `x` is
    projected with a 1x1 conv and 2x nearest-upsampled before adding to `y`."""

    def __init__(self, in_filters, nb_filters, subsample):
        super(_u_shortcut, self).__init__()
        self.process = False
        self.model = None
        if in_filters != nb_filters:
            self.process = True
            self.model = nn.Sequential(
                nn.Conv2d(in_filters, nb_filters, (1, 1), stride=subsample, padding_mode='zeros'),
                nn.Upsample(scale_factor=2, mode='nearest')
            )

    def forward(self, x, y):
        if self.process:
            return self.model(x) + y
        else:
            # NOTE(review): this branch applies no upsampling, so it assumes
            # x already matches y's spatial size when in_filters == nb_filters
            # — confirm callers never hit it with mismatched shapes.
            return x + y
71
+
72
+
73
class basic_block(nn.Module):
    """Residual block: two pre-activation 3x3 convs plus a shortcut merge;
    `init_subsample` > 1 downsamples in the first conv and the shortcut."""

    def __init__(self, in_filters, nb_filters, init_subsample=1):
        super(basic_block, self).__init__()
        self.conv1 = _bn_relu_conv(in_filters, nb_filters, 3, 3, subsample=init_subsample)
        self.residual = _bn_relu_conv(nb_filters, nb_filters, 3, 3)
        self.shortcut = _shortcut(in_filters, nb_filters, subsample=init_subsample)

    def forward(self, x):
        out = self.residual(self.conv1(x))
        return self.shortcut(x, out)
84
+
85
class _u_basic_block(nn.Module):
    """Upsampling residual block: an upsampling pre-activation conv followed
    by a regular pre-activation conv, merged via an upsampling shortcut."""

    def __init__(self, in_filters, nb_filters, init_subsample=1):
        super(_u_basic_block, self).__init__()
        # conv1 performs the 2x upsampling and the channel change.
        self.conv1 = _u_bn_relu_conv(in_filters, nb_filters, 3, 3, subsample=init_subsample)
        self.residual = _bn_relu_conv(nb_filters, nb_filters, 3, 3)
        self.shortcut = _u_shortcut(in_filters, nb_filters, subsample=init_subsample)

    def forward(self, x):
        y = self.residual(self.conv1(x))
        return self.shortcut(x, y)
95
+
96
+
97
class _residual_block(nn.Module):
    """Stack of `repetitions` basic_blocks; the last block downsamples by 2
    unless this is the first layer of the network."""

    def __init__(self, in_filters, nb_filters, repetitions, is_first_layer=False):
        super(_residual_block, self).__init__()
        blocks = []
        for i in range(repetitions):
            is_last = i == repetitions - 1
            stride = 2 if (is_last and not is_first_layer) else 1
            channels_in = in_filters if i == 0 else nb_filters
            blocks.append(
                basic_block(in_filters=channels_in, nb_filters=nb_filters, init_subsample=stride)
            )
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        return self.model(x)
115
+
116
+
117
class _upsampling_residual_block(nn.Module):
    """Stack of blocks where the first one upsamples by 2 (_u_basic_block)
    and the remaining ones preserve resolution (basic_block)."""

    def __init__(self, in_filters, nb_filters, repetitions):
        super(_upsampling_residual_block, self).__init__()
        layers = []
        for i in range(repetitions):
            # (Removed pointless `l = None` pre-initialization from upstream.)
            if i == 0:
                # First block performs the 2x upsampling and channel change.
                block = _u_basic_block(in_filters=in_filters, nb_filters=nb_filters)
            else:
                block = basic_block(in_filters=nb_filters, nb_filters=nb_filters)
            layers.append(block)

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
133
+
134
class res_skip(nn.Module):
    """Manga line extraction network (MangaLineExtraction architecture).

    Five-stage residual encoder (24 -> 384 channels, each stage after the
    first downsamples by 2), mirrored by four upsampling stages whose
    outputs are merged with encoder features through identity shortcuts,
    followed by a final 1x1 head producing a single-channel line map.
    """

    def __init__(self):
        super(res_skip, self).__init__()
        # Encoder.
        self.block0 = _residual_block(in_filters=1, nb_filters=24, repetitions=2, is_first_layer=True)  # (input)
        self.block1 = _residual_block(in_filters=24, nb_filters=48, repetitions=3)  # (block0)
        self.block2 = _residual_block(in_filters=48, nb_filters=96, repetitions=5)  # (block1)
        self.block3 = _residual_block(in_filters=96, nb_filters=192, repetitions=7)  # (block2)
        self.block4 = _residual_block(in_filters=192, nb_filters=384, repetitions=12)  # (block3)

        # Decoder: upsample, then fuse with the matching encoder stage.
        self.block5 = _upsampling_residual_block(in_filters=384, nb_filters=192, repetitions=7)  # (block4)
        self.res1 = _shortcut(in_filters=192, nb_filters=192)  # (block3, block5, subsample=(1,1))

        self.block6 = _upsampling_residual_block(in_filters=192, nb_filters=96, repetitions=5)  # (res1)
        self.res2 = _shortcut(in_filters=96, nb_filters=96)  # (block2, block6, subsample=(1,1))

        self.block7 = _upsampling_residual_block(in_filters=96, nb_filters=48, repetitions=3)  # (res2)
        self.res3 = _shortcut(in_filters=48, nb_filters=48)  # (block1, block7, subsample=(1,1))

        self.block8 = _upsampling_residual_block(in_filters=48, nb_filters=24, repetitions=2)  # (res3)
        self.res4 = _shortcut(in_filters=24, nb_filters=24)  # (block0,block8, subsample=(1,1))

        # Output head at full resolution.
        self.block9 = _residual_block(in_filters=24, nb_filters=16, repetitions=2, is_first_layer=True)  # (res4)
        self.conv15 = _bn_relu_conv(in_filters=16, nb_filters=1, fh=1, fw=1, subsample=1)  # (block7)

    def forward(self, x):
        """Return a single-channel line map for a 1-channel input batch."""
        x0 = self.block0(x)
        x1 = self.block1(x0)
        x2 = self.block2(x1)
        x3 = self.block3(x2)
        x4 = self.block4(x3)

        x5 = self.block5(x4)
        res1 = self.res1(x3, x5)

        x6 = self.block6(res1)
        res2 = self.res2(x2, x6)

        x7 = self.block7(res2)
        res3 = self.res3(x1, x7)

        x8 = self.block8(res3)
        res4 = self.res4(x0, x8)

        x9 = self.block9(res4)
        y = self.conv15(x9)

        return y

    def proceed(self, sketch):
        """Run line extraction on a PIL image; returns values in [-1, 1].

        NOTE(review): hard-codes `.cuda()` — fails on CPU-only machines.
        The result is negated; presumably to match the sign convention the
        downstream pipeline expects — verify against callers.
        """
        # PIL -> (1, 1, H, W) tensor in [0, 255], keeping only channel 0.
        sketch = transforms.ToTensor()(sketch).unsqueeze(0)[:, 0] * 255
        sketch = sketch.unsqueeze(1).cuda()
        sketch = self(sketch) / 127.5 - 1
        return -sketch.clamp(-1, 1)
preprocessor/sk_model.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torchvision.transforms.functional as tf
3
+
4
+
5
+ norm_layer = nn.InstanceNorm2d
6
+
7
class ResidualBlock(nn.Module):
    """Reflection-padded 3x3 conv pair (with instance norm) plus a skip."""

    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        layers = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features),
        ]
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual
24
+
25
+
26
class Generator(nn.Module):
    """ResNet-style image-to-image generator (CycleGAN-like layout).

    7x7 stem -> two stride-2 downsampling convs -> `n_residual_blocks`
    residual blocks -> two stride-2 transposed convs -> 7x7 output conv,
    optionally followed by a sigmoid.
    """

    def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
        super(Generator, self).__init__()

        # Initial convolution block
        model0 = [ nn.ReflectionPad2d(3),
                   nn.Conv2d(input_nc, 64, 7),
                   norm_layer(64),
                   nn.ReLU(inplace=True) ]
        self.model0 = nn.Sequential(*model0)

        # Downsampling: 64 -> 128 -> 256 channels, /4 spatial.
        model1 = []
        in_features = 64
        out_features = in_features*2
        for _ in range(2):
            model1 += [ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                        norm_layer(out_features),
                        nn.ReLU(inplace=True) ]
            in_features = out_features
            out_features = in_features*2
        self.model1 = nn.Sequential(*model1)

        model2 = []
        # Residual blocks
        for _ in range(n_residual_blocks):
            model2 += [ResidualBlock(in_features)]
        self.model2 = nn.Sequential(*model2)

        # Upsampling: 256 -> 128 -> 64 channels, back to input resolution.
        model3 = []
        out_features = in_features//2
        for _ in range(2):
            model3 += [ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                        norm_layer(out_features),
                        nn.ReLU(inplace=True) ]
            in_features = out_features
            out_features = in_features//2
        self.model3 = nn.Sequential(*model3)

        # Output layer
        model4 = [ nn.ReflectionPad2d(3),
                   nn.Conv2d(64, output_nc, 7)]
        if sigmoid:
            model4 += [nn.Sigmoid()]

        self.model4 = nn.Sequential(*model4)

    def forward(self, x, cond=None):
        """Run the generator; `cond` is accepted but unused (kept for
        interface compatibility with callers)."""
        out = self.model0(x)
        out = self.model1(out)
        out = self.model2(out)
        out = self.model3(out)
        out = self.model4(out)

        return out
82
+
83
+
84
class LineartDetector(nn.Module):
    """Line-art extractor wrapping a small 3-residual-block Generator."""

    def __init__(self):
        super().__init__()
        # 3-channel input, 1-channel output, 3 residual blocks.
        self.model = Generator(3, 1, 3)

    def load_state_dict(self, sd):
        # Delegate so checkpoints saved for the bare Generator load directly
        # (their keys are not prefixed with "model.").
        self.model.load_state_dict(sd)

    def proceed(self, sketch):
        """Convert a PIL image to a (1, 3, H, W) float batch and run the
        generator, returning the negated output.

        NOTE(review): hard-codes `.cuda()` — fails on CPU-only machines.
        """
        sketch = tf.pil_to_tensor(sketch).unsqueeze(0).cuda().float()
        return -self.model(sketch)
preprocessor/sketchKeras.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+
8
def postprocess(pred, thresh=0.18):
    """Collapse predictions over dim 0 and rescale to [-1, 1].

    Takes the element-wise max over the first dimension, zeroes entries
    below `thresh`, then maps [0, 1] -> [-1, 1].

    Args:
        pred: prediction tensor; reduced over dim 0.
        thresh: cutoff in [0, 1] below which values are zeroed.

    Returns:
        Reduced tensor in [-1, 1].

    Raises:
        ValueError: if `thresh` is outside [0, 1].
    """
    # Explicit validation instead of `assert` (asserts vanish under -O).
    if not 0.0 <= thresh <= 1.0:
        raise ValueError(f"thresh must be in [0, 1], got {thresh}")

    pred = torch.amax(pred, 0)
    pred = torch.where(pred < thresh, torch.zeros_like(pred), pred)
    # Map [0, 1] -> [-1, 1].
    return pred * 2 - 1
16
+
17
+
18
class SketchKeras(nn.Module):
    """PyTorch port of the sketchKeras line-extraction U-Net.

    Six downsampling conv blocks (1 -> 512 channels, /16 spatial) followed
    by four bicubic-upsampling blocks with channel-concatenated skip
    connections, ending in a single-channel 3x3 conv head.
    """

    def __init__(self):
        super(SketchKeras, self).__init__()

        # Encoder. Stride-2 convs (kernel 4) halve the resolution; each
        # block ends in BatchNorm (momentum=0, i.e. frozen running stats
        # after loading) and ReLU.
        self.downblock_1 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(1, 32, kernel_size=3, stride=1),
            nn.BatchNorm2d(32, eps=1e-3, momentum=0),
            nn.ReLU(),
        )
        self.downblock_2 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.BatchNorm2d(64, eps=1e-3, momentum=0),
            nn.ReLU(),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.BatchNorm2d(64, eps=1e-3, momentum=0),
            nn.ReLU(),
        )
        self.downblock_3 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(64, 128, kernel_size=4, stride=2),
            nn.BatchNorm2d(128, eps=1e-3, momentum=0),
            nn.ReLU(),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(128, 128, kernel_size=3, stride=1),
            nn.BatchNorm2d(128, eps=1e-3, momentum=0),
            nn.ReLU(),
        )
        self.downblock_4 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(128, 256, kernel_size=4, stride=2),
            nn.BatchNorm2d(256, eps=1e-3, momentum=0),
            nn.ReLU(),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(256, 256, kernel_size=3, stride=1),
            nn.BatchNorm2d(256, eps=1e-3, momentum=0),
            nn.ReLU(),
        )
        self.downblock_5 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(256, 512, kernel_size=4, stride=2),
            nn.BatchNorm2d(512, eps=1e-3, momentum=0),
            nn.ReLU(),
        )
        self.downblock_6 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(512, 512, kernel_size=3, stride=1),
            nn.BatchNorm2d(512, eps=1e-3, momentum=0),
            nn.ReLU(),
        )

        # Decoder. Each block bicubic-upsamples 2x; the asymmetric
        # (1, 2, 1, 2) padding keeps sizes aligned with the kernel-4 convs.
        # Input channels are doubled by the skip concatenations in forward().
        self.upblock_1 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bicubic"),
            nn.ReflectionPad2d((1, 2, 1, 2)),
            nn.Conv2d(1024, 512, kernel_size=4, stride=1),
            nn.BatchNorm2d(512, eps=1e-3, momentum=0),
            nn.ReLU(),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(512, 256, kernel_size=3, stride=1),
            nn.BatchNorm2d(256, eps=1e-3, momentum=0),
            nn.ReLU(),
        )

        self.upblock_2 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bicubic"),
            nn.ReflectionPad2d((1, 2, 1, 2)),
            nn.Conv2d(512, 256, kernel_size=4, stride=1),
            nn.BatchNorm2d(256, eps=1e-3, momentum=0),
            nn.ReLU(),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(256, 128, kernel_size=3, stride=1),
            nn.BatchNorm2d(128, eps=1e-3, momentum=0),
            nn.ReLU(),
        )

        self.upblock_3 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bicubic"),
            nn.ReflectionPad2d((1, 2, 1, 2)),
            nn.Conv2d(256, 128, kernel_size=4, stride=1),
            nn.BatchNorm2d(128, eps=1e-3, momentum=0),
            nn.ReLU(),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(128, 64, kernel_size=3, stride=1),
            nn.BatchNorm2d(64, eps=1e-3, momentum=0),
            nn.ReLU(),
        )

        self.upblock_4 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bicubic"),
            nn.ReflectionPad2d((1, 2, 1, 2)),
            nn.Conv2d(128, 64, kernel_size=4, stride=1),
            nn.BatchNorm2d(64, eps=1e-3, momentum=0),
            nn.ReLU(),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(64, 32, kernel_size=3, stride=1),
            nn.BatchNorm2d(32, eps=1e-3, momentum=0),
            nn.ReLU(),
        )

        # Final head consumes the last skip concat (32 + 32 = 64 channels).
        self.last_pad = nn.ReflectionPad2d((1, 1, 1, 1))
        self.last_conv = nn.Conv2d(64, 1, kernel_size=3, stride=1)

    def forward(self, x):
        """U-Net forward pass; returns a single-channel map at input size."""
        d1 = self.downblock_1(x)
        d2 = self.downblock_2(d1)
        d3 = self.downblock_3(d2)
        d4 = self.downblock_4(d3)
        d5 = self.downblock_5(d4)
        d6 = self.downblock_6(d5)

        # Decoder with channel-wise skip concatenations.
        u1 = torch.cat((d5, d6), dim=1)
        u1 = self.upblock_1(u1)
        u2 = torch.cat((d4, u1), dim=1)
        u2 = self.upblock_2(u2)
        u3 = torch.cat((d3, u2), dim=1)
        u3 = self.upblock_3(u3)
        u4 = torch.cat((d2, u3), dim=1)
        u4 = self.upblock_4(u4)
        u5 = torch.cat((d1, u4), dim=1)

        out = self.last_conv(self.last_pad(u5))

        return out

    def proceed(self, img):
        """Extract a sketch map from an image (PIL or ndarray).

        Builds a high-pass input by subtracting a Gaussian blur, normalizes
        by the max, runs the network per channel (channels are moved to the
        batch dim), then thresholds/rescales via `postprocess`.

        NOTE(review): hard-codes `.cuda()` — fails on CPU-only machines.
        """
        img = np.array(img)
        blurred = cv2.GaussianBlur(img, (0, 0), 3)
        # High-pass residual; int arithmetic avoids uint8 wrap-around.
        img = img.astype(int) - blurred.astype(int)
        img = img.astype(np.float32) / 127.5
        img /= np.max(img)
        # (H, W, C) -> (C, 1, H, W): each channel becomes a batch item.
        img = torch.tensor(img).unsqueeze(0).permute(3, 0, 1, 2).cuda()
        img = self(img)
        img = postprocess(img, thresh=0.1).unsqueeze(1)
        return img
refnet/__init__.py ADDED
File without changes
refnet/ldm/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .ddpm import LatentDiffusion
refnet/ldm/ddpm.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ wild mixture of
3
+ https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
4
+ https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
5
+ https://github.com/CompVis/taming-transformers
6
+ -- merci
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import numpy as np
12
+ from contextlib import contextmanager
13
+ from functools import partial
14
+
15
+ from refnet.util import default, count_params, instantiate_from_config, exists
16
+ from refnet.ldm.util import make_beta_schedule, extract_into_tensor
17
+
18
+
19
+
20
def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore.

    Bound onto frozen submodules so ``.train()``/``.eval()`` calls become
    no-ops; the ``mode`` argument is accepted but ignored on purpose.
    """
    return self
24
+
25
+
26
def uniform_on_device(r1, r2, shape, device):
    """Draw a tensor of `shape` uniformly from the interval between r2 and r1,
    allocated directly on `device`."""
    span = r1 - r2
    return span * torch.rand(*shape, device=device) + r2
28
+
29
+
30
def rescale_zero_terminal_snr(betas):
    """Rescale a beta schedule to have zero terminal SNR.

    Implements Algorithm 1 of https://arxiv.org/pdf/2305.08891.pdf: shift the
    sqrt of the cumulative alpha curve so its final value is exactly zero,
    rescale so its first value is unchanged, then convert back to betas.

    Args:
        betas (`torch.FloatTensor`): the betas the scheduler is initialized with.

    Returns:
        `torch.FloatTensor`: rescaled betas with zero terminal SNR.
    """
    cum_sqrt = torch.cumprod(1.0 - betas, dim=0).sqrt()

    # Remember the endpoints before shifting.
    first = cum_sqrt[0].clone()
    last = cum_sqrt[-1].clone()

    # Shift so the last step is zero, then scale so the first step is preserved.
    cum_sqrt = (cum_sqrt - last) * (first / (first - last))

    # Undo sqrt and cumprod to recover the per-step alphas, hence betas.
    alphas_bar = cum_sqrt ** 2
    stepwise_alphas = torch.cat([alphas_bar[0:1], alphas_bar[1:] / alphas_bar[:-1]])
    return 1 - stepwise_alphas
64
+
65
+
66
class DDPM(nn.Module):
    # classic DDPM with Gaussian diffusion, in image space
    """Base denoising-diffusion module.

    Builds the wrapped UNet and registers the full beta/alpha noise schedule
    as float32 buffers; q-sampling and v-prediction helpers live here, while
    training and sampling logic live in subclasses.
    """
    def __init__(
        self,
        unet_config,
        timesteps = 1000,
        beta_schedule = "scaled_linear",
        image_size = 256,
        channels = 3,
        linear_start = 1e-4,
        linear_end = 2e-2,
        cosine_s = 8e-3,
        v_posterior = 0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
        parameterization = "eps", # all assuming fixed variance schedules
        zero_snr = False,
        half_precision_dtype = "float16",
        version = "sdv1",
        *args,
        **kwargs
    ):
        """Instantiate the diffusion UNet from `unet_config` and register the
        noise schedule.

        parameterization: "eps" predicts noise, "v" predicts velocity.
        zero_snr: rescale betas to zero terminal SNR (v-prediction only).
        version: "sdxl" toggles SDXL-specific handling via `self.is_sdxl`.
        """
        super().__init__()
        assert parameterization in ["eps", "v"], "currently only supporting 'eps' and 'v'"
        assert half_precision_dtype in ["float16", "bfloat16"], "K-diffusion samplers do not support bfloat16, use float16 by default"
        if zero_snr:
            assert parameterization == "v", 'Zero SNR is only available for "v-prediction" model.'

        self.is_sdxl = (version == "sdxl")
        self.parameterization = parameterization
        print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
        # Conditioning encoders are attached later by subclasses.
        self.cond_stage_model = None
        self.img_embedder = None
        self.image_size = image_size  # try conv?
        self.channels = channels
        self.model = DiffusionWrapper(unet_config)
        count_params(self.model, verbose=True)
        self.v_posterior = v_posterior
        self.half_precision_dtype = torch.bfloat16 if half_precision_dtype == "bfloat16" else torch.float16
        self.register_schedule(beta_schedule=beta_schedule, timesteps=timesteps,
                               linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s, zero_snr=zero_snr)

    def register_schedule(self, beta_schedule="scaled_linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, zero_snr=False):
        """Precompute every schedule-derived constant and register each as a
        float32 buffer (so they move with the module across devices/dtypes)."""
        betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
                                   cosine_s=cosine_s, zero_snr=zero_snr)

        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
                1. - alphas_cumprod) + self.v_posterior * betas
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer('posterior_variance', to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
        self.register_buffer('posterior_mean_coef1', to_torch(
            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
        self.register_buffer('posterior_mean_coef2', to_torch(
            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))

    @contextmanager
    def ema_scope(self, context=None):
        """Temporarily swap EMA weights into the wrapped model, restoring the
        training weights on exit.

        NOTE(review): depends on `self.use_ema` and `self.model_ema`, neither
        of which is defined in this class — presumably set by a subclass or
        training mixin; calling this without them raises AttributeError.
        """
        if self.use_ema:
            self.model_ema.store(self.model.parameters())
            self.model_ema.copy_to(self.model)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.model.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def predict_start_from_z_and_v(self, x_t, t, v):
        """Recover x_0 from noisy latent x_t and a v-prediction:
        x_0 = sqrt(abar_t) * x_t - sqrt(1 - abar_t) * v."""
        # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        return (
                extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
        )

    def add_noise(self, x_start, t, noise=None):
        """Forward diffusion: sample q(x_t | x_0) at timestep(s) t.
        If `noise` is None a fresh standard-normal sample is drawn."""
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise).to(x_start.dtype)

    def get_v(self, x, noise, t):
        """Velocity target: v = sqrt(abar_t) * noise - sqrt(1 - abar_t) * x."""
        return (
                extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
        )

    def normalize_timesteps(self, timesteps):
        """Identity here; subclasses may override to rescale timesteps."""
        return timesteps
185
+
186
+
187
class LatentDiffusion(DDPM):
    """main class"""

    def __init__(
        self,
        first_stage_config,
        cond_stage_config,
        scale_factor = 1.0,
        *args,
        **kwargs
    ):
        """Diffusion in the latent space of a frozen first-stage autoencoder.

        Both the first-stage (VAE) and conditioning models are instantiated
        from their configs, put in eval mode, and frozen (no gradients).
        scale_factor: multiplier applied to first-stage latents.
        """
        super().__init__(*args, **kwargs)
        self.scale_factor = scale_factor
        self.first_stage_model, self.cond_stage_model = map(
            lambda t: instantiate_from_config(t).eval().requires_grad_(False),
            (first_stage_config, cond_stage_config)
        )

    @torch.no_grad()
    def get_first_stage_encoding(self, x):
        """Encode an image batch into a scaled latent sample (no grad)."""
        encoder_posterior = self.first_stage_model.encode(x)
        # .sample() draws from the encoder's posterior distribution.
        z = encoder_posterior.sample() * self.scale_factor
        # NOTE(review): `self.dtype` is assumed to be provided elsewhere
        # (nn.Module does not define it) — confirm against subclasses.
        return z.to(self.dtype).detach()

    @torch.no_grad()
    def decode_first_stage(self, z):
        """Inverse of get_first_stage_encoding: unscale then decode to images."""
        z = 1. / self.scale_factor * z
        return self.first_stage_model.decode(z.to(self.first_stage_model.dtype)).detach()

    def apply_model(self, x_noisy, t, cond):
        """Run the wrapped UNet on a noisy latent with conditioning dict `cond`."""
        return self.model(x_noisy, t, **cond)

    def get_learned_embedding(self, c, *args, **kwargs):
        """Return (wd_emb, wd_logits, clip_emb) for reference image(s) `c`,
        all detached from the graph (wd outputs may be None)."""
        wd_emb, wd_logits = map(lambda t: t.detach() if exists(t) else None, self.img_embedder.encode(c, **kwargs))
        clip_emb = self.cond_stage_model.encode(c, **kwargs).detach()
        return wd_emb, wd_logits, clip_emb
223
+
224
+
225
class DiffusionWrapper(nn.Module):
    """Thin wrapper around the UNet that fuses list-valued conditionings
    into single tensors before the forward call."""

    def __init__(self, diff_model_config):
        super().__init__()
        self.diffusion_model = instantiate_from_config(diff_model_config)

    def forward(self, x, t, **cond):
        # These conditioning entries arrive as lists of tensors; concatenate
        # each along dim 1 (channel/feature axis). Other entries pass through.
        for key in ("context", "y", "concat"):
            if key in cond:
                cond[key] = torch.cat(cond[key], 1)
        return self.diffusion_model(x, t, **cond)
refnet/ldm/openaimodel.py ADDED
@@ -0,0 +1,386 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import abstractmethod
2
+ import math
3
+
4
+ import numpy as np
5
+ import torch as th
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from refnet.ldm.util import (
10
+ conv_nd,
11
+ linear,
12
+ avg_pool_nd,
13
+ zero_module,
14
+ normalization,
15
+ timestep_embedding,
16
+ )
17
+ from refnet.util import checkpoint_wrapper
18
+
19
+
20
+
21
+ # dummy replace
22
# dummy replace
def convert_module_to_f16(x):
    """Intentional no-op ("dummy replace" above): kept so callers of the
    original OpenAI UNet API still work; presumably precision handling is
    done elsewhere in this codebase — confirm before relying on it."""
    pass

def convert_module_to_f32(x):
    """Intentional no-op counterpart of convert_module_to_f16."""
    pass
27
+
28
+
29
+ ## go
30
## go
class AttentionPool2d(nn.Module):
    """
    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py

    Pools a (N, C, H, W) feature map to a (N, output_dim) vector by attending
    over all spatial positions plus a prepended mean token.
    """

    def __init__(
        self,
        spacial_dim: int,
        embed_dim: int,
        num_heads_channels: int,
        output_dim: int = None,
    ):
        super().__init__()
        # Learned positional embedding over HW positions plus the mean token.
        self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        b, c, *_spatial = x.shape
        x = x.reshape(b, c, -1)  # NC(HW)
        # Prepend the spatial mean as a CLS-like token to pool into.
        x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)
        x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)
        x = self.qkv_proj(x)
        x = self.attention(x)
        x = self.c_proj(x)
        # Return only the pooled (mean) token.
        return x[:, :, 0]
58
+
59
+
60
class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.

    Abstract marker class: sequential containers use it to decide which
    children receive the embedding.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """
70
+
71
+
72
class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            # 3D inputs: keep the leading (temporal) axis, double the two spatial axes.
            target = (x.shape[2], x.shape[3] * 2, x.shape[4] * 2)
            x = F.interpolate(x, target, mode="nearest")
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(x) if self.use_conv else x
101
+
102
class TransposedUpsample(nn.Module):
    'Learned 2x upsampling without padding'

    def __init__(self, channels, out_channels=None, ks=5):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        # A stride-2 transposed convolution roughly doubles spatial resolution.
        self.up = nn.ConvTranspose2d(self.channels, self.out_channels, kernel_size=ks, stride=2)

    def forward(self, x):
        return self.up(x)
113
+
114
+
115
class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        # 3D inputs keep the leading (temporal) axis at full resolution.
        stride = 2 if dims != 3 else (1, 2, 2)
        if not use_conv:
            # Average pooling cannot change the channel count.
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
        else:
            self.op = conv_nd(
                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
            )

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
142
+
143
+
144
class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        # Optional built-in resampling: the hidden path and the skip path are
        # resampled identically so they can still be summed.
        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        # Projects the timestep embedding; doubled width when scale-shift
        # (FiLM-style) conditioning is used, since it emits (scale, shift).
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            # Zero-initialized conv so the block starts as an identity residual.
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    @checkpoint_wrapper
    def forward(self, x, emb):
        """Apply the block to a (N, C, ...) tensor conditioned on `emb`."""
        if self.updown:
            # Resample between the norm/activation and the conv of in_layers.
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        # Append singleton dims so the embedding broadcasts over spatial axes.
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            # FiLM-style conditioning: h = norm(h) * (1 + scale) + shift.
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h
246
+
247
+
248
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.
    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        # Either a fixed head count, or derive it from a per-head channel width.
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        if use_new_attention_order:
            # split qkv before split heads
            self.attention = QKVAttention(self.num_heads)
        else:
            # split heads before split qkv
            self.attention = QKVAttentionLegacy(self.num_heads)

        # Zero-initialized projection: the block starts as a pure residual identity.
        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    @checkpoint_wrapper
    def forward(self, x):
        # Flatten all spatial dims into one token axis, attend, restore shape.
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)
292
+
293
+
294
def count_flops_attn(model, _x, y):
    """
    A counter for the `thop` package to count the operations in an
    attention operation.
    Meant to be used like:
        macs, params = thop.profile(
            model,
            inputs=(inputs, timestamps),
            custom_ops={QKVAttention: QKVAttention.count_flops},
        )
    """
    b, c, *spatial = y[0].shape
    tokens = int(np.prod(spatial))
    # Two matmuls of identical cost: one for the attention weight matrix,
    # one for combining the value vectors.
    flops = 2 * b * (tokens ** 2) * c
    model.total_ops += th.DoubleTensor([flops])
312
+
313
+
314
class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        batch, width, tokens = qkv.shape
        assert width % (3 * self.n_heads) == 0
        head_dim = width // (3 * self.n_heads)
        # Legacy layout: heads are split first, then q/k/v within each head.
        q, k, v = qkv.reshape(batch * self.n_heads, head_dim * 3, tokens).split(head_dim, dim=1)
        # Scale q and k by sqrt(scale) each: more stable in fp16 than
        # dividing the product afterwards.
        scale = 1 / math.sqrt(math.sqrt(head_dim))
        weight = th.einsum("bct,bcs->bts", q * scale, k * scale)
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = th.einsum("bts,bcs->bct", weight, v)
        return out.reshape(batch, -1, tokens)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
344
+
345
+
346
class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        batch, width, tokens = qkv.shape
        assert width % (3 * self.n_heads) == 0
        head_dim = width // (3 * self.n_heads)
        # "New" layout: q/k/v are split first, heads second.
        q, k, v = qkv.chunk(3, dim=1)
        # Symmetric sqrt scaling of q and k is more stable in fp16 than
        # dividing the product afterwards.
        scale = 1 / math.sqrt(math.sqrt(head_dim))
        weight = th.einsum(
            "bct,bcs->bts",
            (q * scale).view(batch * self.n_heads, head_dim, tokens),
            (k * scale).view(batch * self.n_heads, head_dim, tokens),
        )
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = th.einsum("bts,bcs->bct", weight, v.reshape(batch * self.n_heads, head_dim, tokens))
        return out.reshape(batch, -1, tokens)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
378
+
379
+
380
class Timestep(nn.Module):
    """Module wrapper producing sinusoidal timestep embeddings of width `dim`."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        # Delegates to the shared sinusoidal embedding helper.
        return timestep_embedding(t, self.dim)
refnet/ldm/util.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # adopted from
2
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3
+ # and
4
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5
+ # and
6
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7
+ #
8
+ # thanks!
9
+
10
+
11
+ import math
12
+ import torch
13
+ import torch.nn as nn
14
+ import numpy as np
15
+ from einops import repeat
16
+
17
+
18
def rescale_zero_terminal_snr(betas):
    """Rescale a beta schedule to have zero terminal SNR.

    Algorithm 1 of https://arxiv.org/pdf/2305.08891.pdf: work on the sqrt of
    the cumulative-alpha curve, shift it so the final value is exactly zero,
    rescale to keep the initial value, then convert back to betas.

    Args:
        betas (`torch.FloatTensor`): the betas the scheduler is initialized with.

    Returns:
        `torch.FloatTensor`: rescaled betas with zero terminal SNR.
    """
    alphas_bar_sqrt = torch.cumprod(1.0 - betas, dim=0).sqrt()

    # Endpoints of the original curve.
    start_val = alphas_bar_sqrt[0].clone()
    end_val = alphas_bar_sqrt[-1].clone()

    # Shift so the terminal value is zero; rescale so the first value is preserved.
    alphas_bar_sqrt = (alphas_bar_sqrt - end_val) * (start_val / (start_val - end_val))

    # Invert the sqrt and the cumulative product to recover per-step betas.
    alphas_bar = alphas_bar_sqrt ** 2
    alphas = torch.cat([alphas_bar[0:1], alphas_bar[1:] / alphas_bar[:-1]])
    return 1 - alphas
52
+
53
+
54
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, zero_snr=False):
    """Build a diffusion beta schedule.

    :param schedule: one of "linear", "scaled_linear", "cosine",
        "squaredcos_cap_v2", "sqrt_linear", "sqrt".
    :param n_timestep: number of diffusion steps (length of the result).
    :param linear_start: first beta for the linear-family schedules.
    :param linear_end: last beta for the linear-family schedules.
    :param cosine_s: small offset for the cosine schedule.
    :param zero_snr: if True, rescale for zero terminal SNR (not applied to
        the "squaredcos_cap_v2" branch, which returns early by design).
    :return: a float64 numpy array of betas.
    :raises ValueError: for an unknown schedule name.
    """
    if schedule == "linear":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
    elif schedule == "scaled_linear":
        # Linear in sqrt(beta) space, then squared (the SD default).
        betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
    elif schedule == "cosine":
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        # Bugfix: use torch.clamp instead of np.clip here. np.clip on a
        # torch.Tensor may return an ndarray depending on the numpy/torch
        # versions, which would break the `betas.numpy()` call below.
        betas = torch.clamp(betas, min=0, max=0.999)
    elif schedule == "squaredcos_cap_v2":  # used for karlo prior
        # return early
        return betas_for_alpha_bar(
            n_timestep,
            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
        )
    elif schedule == "sqrt_linear":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
    elif schedule == "sqrt":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")

    if zero_snr:
        betas = rescale_zero_terminal_snr(betas)
    return betas.numpy()
91
+
92
+
93
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
    """Select the DDPM timesteps used for DDIM sampling.

    :param ddim_discr_method: 'uniform' (even stride) or 'quad' (quadratic spacing).
    :return: numpy int array of selected timesteps, shifted by +1.
    """
    if ddim_discr_method == 'uniform':
        stride = num_ddpm_timesteps // num_ddim_timesteps
        ddim_timesteps = np.arange(0, num_ddpm_timesteps, stride)
    elif ddim_discr_method == 'quad':
        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
    else:
        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')

    # assert ddim_timesteps.shape[0] == num_ddim_timesteps
    # add one to get the final alpha values right (the ones from first scale to data during sampling)
    steps_out = ddim_timesteps + 1
    if verbose:
        print(f'Selected timesteps for ddim sampler: {steps_out}')
    return steps_out
108
+
109
+
110
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
    """Compute (sigmas, alphas, alphas_prev) for DDIM sampling.

    :param alphacums: full cumulative-alpha schedule (numpy array).
    :param ddim_timesteps: indices selected by make_ddim_timesteps.
    :param eta: stochasticity weight; eta=0 gives deterministic DDIM.
    """
    # Variance-schedule alphas at the selected steps, and the preceding ones.
    alphas = alphacums[ddim_timesteps]
    alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())

    # according the the formula provided in https://arxiv.org/abs/2010.02502
    sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
    if verbose:
        print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
        print(f'For the chosen value of eta, which is {eta}, '
              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
    return sigmas, alphas, alphas_prev
122
+
123
+
124
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].
    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    # beta_i = 1 - abar(t_{i+1}) / abar(t_i), capped at max_beta.
    edges = [
        (i / num_diffusion_timesteps, (i + 1) / num_diffusion_timesteps)
        for i in range(num_diffusion_timesteps)
    ]
    return np.array([min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta) for t1, t2 in edges])
141
+
142
+
143
def extract_into_tensor(a, t, x_shape):
    """Gather per-timestep coefficients from `a` at indices `t` and reshape to
    (batch, 1, 1, ...) for broadcasting against a tensor of shape `x_shape`."""
    batch = t.shape[0]
    picked = a.gather(-1, t)
    trailing = (1,) * (len(x_shape) - 1)
    return picked.reshape(batch, *trailing)
147
+
148
+
149
+
150
class CheckpointFunction(torch.autograd.Function):
    """Gradient checkpointing: run the forward without storing intermediate
    activations, then recompute them during backward to trade compute for
    memory."""

    @staticmethod
    def forward(ctx, run_function, length, *args):
        # The first `length` args are the tensors to recompute from; the rest
        # are parameters that only need to receive gradients.
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        # Capture the current autocast state so the backward recompute runs
        # under the same mixed-precision settings as this forward did.
        ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
                                   "dtype": torch.get_autocast_gpu_dtype(),
                                   "cache_enabled": torch.is_autocast_cache_enabled()}
        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        # Re-run the forward with grad enabled on detached copies of the inputs.
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad(), \
                torch.amp.autocast("cuda", **ctx.gpu_autocast_kwargs):
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        # The leading (None, None) matches the non-tensor run_function/length args.
        return (None, None) + input_grads
183
+
184
+
185
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: if True, just tile the raw timestep across `dim`.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if repeat_only:
        return repeat(timesteps, 'b -> b d', d=dim)
    half = dim // 2
    # Geometric frequency ladder from 1 down to 1/max_period.
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Odd output width: pad the final column with zeros.
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding
206
+
207
+
208
def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
215
+
216
+
217
def scale_module(module, scale):
    """
    Scale the parameters of a module (in place) and return it.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.mul_(scale)
    return module
224
+
225
+
226
def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    non_batch_dims = tuple(range(1, tensor.ndim))
    return tensor.mean(dim=non_batch_dims)
231
+
232
+
233
def normalization(channels):
    """
    Make a standard normalization layer.
    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    # 32-group GroupNorm, computed in the weight dtype (see GroupNorm32).
    return GroupNorm32(32, channels)
240
+
241
+
242
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
243
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
    """SiLU ("swish") activation: x * sigmoid(x)."""

    def forward(self, x):
        return torch.sigmoid(x) * x
246
+
247
+
248
class GroupNorm32(nn.GroupNorm):
    """GroupNorm evaluated in the weight's dtype, with the result cast back
    to the input's dtype (keeps normalization stable under fp16 inputs)."""

    def forward(self, x):
        in_dtype = x.dtype
        out = super().forward(x.to(self.weight.dtype))
        return out.type(in_dtype)
251
+
252
+
253
def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    factories = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
    if dims not in factories:
        raise ValueError(f"unsupported dimensions: {dims}")
    return factories[dims](*args, **kwargs)
264
+
265
+
266
def linear(*args, **kwargs):
    """
    Create a linear module.

    Thin pass-through to nn.Linear, kept so callers use a single factory
    alongside conv_nd/avg_pool_nd.
    """
    return nn.Linear(*args, **kwargs)
271
+
272
+
273
def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    factories = {1: nn.AvgPool1d, 2: nn.AvgPool2d, 3: nn.AvgPool3d}
    if dims not in factories:
        raise ValueError(f"unsupported dimensions: {dims}")
    return factories[dims](*args, **kwargs)
284
+
285
+
286
def noise_like(shape, device, repeat=False):
    """Sample standard-normal noise of `shape` on `device`.

    With repeat=True, a single sample of shape (1, *shape[1:]) is drawn and
    tiled across the batch dimension (every batch element gets the same noise).
    """
    if repeat:
        single = torch.randn((1, *shape[1:]), device=device)
        return single.repeat(shape[0], *((1,) * (len(shape) - 1)))
    return torch.randn(shape, device=device)
refnet/modules/__init__.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import namedtuple
2
+
3
+
4
def wd_v14_swin2_tagger_config():
    """Return the fixed timm configuration for the WD v1.4 SwinV2 tagger
    (SmilingWolf/wd-v1-4-swinv2-tagger-v2) as an immutable namedtuple."""
    fields = [
        'architecture', 'num_classes', 'num_features', 'global_pool',
        'model_args', 'pretrained_cfg',
    ]
    CustomConfig = namedtuple('CustomConfig', fields)

    model_args = {
        "act_layer": "gelu",
        "img_size": 448,
        "window_size": 14,
    }
    pretrained_cfg = {
        "custom_load": False,
        "input_size": [3, 448, 448],
        "fixed_input_size": False,
        "interpolation": "bicubic",
        "crop_pct": 1.0,
        "crop_mode": "center",
        "mean": [0.5, 0.5, 0.5],
        "std": [0.5, 0.5, 0.5],
        "num_classes": 9083,
        "pool_size": None,
        "first_conv": None,
        "classifier": None,
    }
    return CustomConfig(
        architecture="swinv2_base_window8_256",
        num_classes=9083,
        num_features=1024,
        global_pool="avg",
        model_args=model_args,
        pretrained_cfg=pretrained_cfg,
    )
refnet/modules/attention.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from calendar import c
2
+ import torch.nn as nn
3
+
4
+ from einops import rearrange
5
+ from refnet.util import exists, default, checkpoint_wrapper
6
+ from .layers import RMSNorm
7
+ from .attn_utils import *
8
+
9
+
10
def create_masked_attention_bias(
    mask: torch.Tensor,
    threshold: float,
    num_heads: int,
    context_len: int
):
    """Build a per-head additive attention bias that routes foreground query
    positions (``mask > threshold``) to the first half of the context and
    background positions to the second half.

    The context axis is padded up to a multiple of 8; padding columns are
    blocked for foreground rows and left open for background rows, matching
    how the half-split biases are laid out.

    Returns a tensor of shape ``(b, num_heads, seq_len, padded_context_len)``.
    """
    batch, seq_len, _ = mask.shape
    half_len = context_len // 2

    # Round the context length up to the next multiple of 8.
    padded_len = -(-context_len // 8) * 8

    neg_inf = -float('inf')
    fg_bias = torch.zeros(batch, seq_len, padded_len, device=mask.device, dtype=mask.dtype)
    bg_bias = torch.zeros(batch, seq_len, padded_len, device=mask.device, dtype=mask.dtype)
    fg_bias[:, :, half_len:] = neg_inf   # foreground rows see only the first half
    bg_bias[:, :, :half_len] = neg_inf   # background rows see the second half (and padding)

    bias = torch.where(mask > threshold, fg_bias, bg_bias)
    return bias.unsqueeze(1).repeat_interleave(num_heads, dim=1)
31
+
32
class Identity(nn.Module):
    """No-op module: returns its first argument unchanged, silently accepting
    any extra positional/keyword arguments (used as a disabled-feature stub)."""

    def forward(self, x, *args, **kwargs):
        return x
38
+
39
+
40
# Rotary Positional Embeddings implementation
class RotaryPositionalEmbeddings(nn.Module):
    """2D rotary positional embedding (RoPE) applied per attention head.

    Half of each head's channels are rotated by row position and half by
    column position; ``grid_size=(h, w)`` at forward time selects the first
    ``h``/``w`` rows of the precomputed complex frequency tables.
    """

    def __init__(self, dim, max_seq_len=1024, theta=10000.0):
        super().__init__()
        assert dim % 2 == 0, "Dimension must be divisible by 2"
        # Each spatial axis (height/width) gets half of the head channels.
        dim = dim // 2
        self.max_seq_len = max_seq_len
        # Complex exponentials e^{i * pos * freq}; standard RoPE frequency ladder.
        freqs = torch.outer(
            torch.arange(self.max_seq_len),
            1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim))
        )
        freqs = torch.polar(torch.ones_like(freqs), freqs)
        # The same table serves both axes; non-persistent (kept out of state_dict).
        self.register_buffer("freq_h", freqs, persistent=False)
        self.register_buffer("freq_w", freqs, persistent=False)

    def forward(self, x, grid_size):
        # x: (batch, seq, heads, head_dim); assumes seq == h * w — TODO confirm.
        bs, seq_len, heads = x.shape[:3]
        h, w = grid_size

        # Pair adjacent channels into complex numbers so rotation is a multiply.
        x_complex = torch.view_as_complex(
            x.float().reshape(bs, seq_len, heads, -1, 2)
        )
        # Row frequencies broadcast over columns and vice versa, then flatten
        # to the (h*w) sequence layout; the singleton head axis broadcasts.
        freqs = torch.cat([
            self.freq_h[:h].view(1, h, 1, -1).expand(bs, h, w, -1),
            self.freq_w[:w].view(1, 1, w, -1).expand(bs, h, w, -1)
        ], dim=-1).reshape(bs, seq_len, 1, -1)

        x_out = x_complex * freqs
        x_out = torch.view_as_real(x_out).flatten(3)

        return x_out.type_as(x)
71
+
72
+
73
class MemoryEfficientAttention(nn.Module):
    # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
    """Multi-head attention with pluggable backends (via ``attn_processor``),
    optional RMS q/k normalization, optional 2D RoPE, and a foreground/
    background "split cross-attention" path driven by a spatial mask.
    """

    def __init__(
        self,
        query_dim,
        context_dim = None,
        heads = None,
        dim_head = 64,
        dropout = 0.0,
        log = False,
        causal = False,
        rope = False,
        max_seq_len = 1024,
        qk_norm = False,
        **kwargs
    ):
        super().__init__()
        if log:
            print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
                  f"{heads} heads.")

        # Derive head count from query width when not given explicitly.
        heads = heads or query_dim // dim_head
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.heads = heads
        self.dim_head = dim_head

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

        # Identity() stubs keep the call sites uniform when features are off.
        self.q_norm = RMSNorm(inner_dim) if qk_norm else Identity()
        self.k_norm = RMSNorm(inner_dim) if qk_norm else Identity()
        self.rope = RotaryPositionalEmbeddings(dim_head, max_seq_len=max_seq_len) if rope else Identity()
        # NOTE(review): `causal_ops` should come from the `attn_utils` star
        # import, but it is not defined in that module as far as visible here —
        # `causal=True` would raise NameError; confirm.
        self.attn_ops = causal_ops if causal else {}

        # default setting for split cross-attention
        self.bg_scale = 1.
        self.fg_scale = 1.
        self.merge_scale = 0.
        self.mask_threshold = 0.05

    @checkpoint_wrapper
    def forward(
        self,
        x,
        context=None,
        mask=None,
        scale=1.,
        scale_factor=None,
        grid_size=None,
        **kwargs,
    ):
        # Self-attention when no context is given.
        context = default(context, x)

        if exists(mask):
            # Mask present: route to the split foreground/background path.
            out = self.masked_forward(x, context, mask, scale, scale_factor)
        else:
            q = self.to_q(x)
            k = self.to_k(context)
            v = self.to_v(context)
            out = self.attn_forward(q, k, v, scale, grid_size)

        return self.to_out(out)

    def attn_forward(self, q, k, v, scale=1., grid_size=None, mask=None):
        # Normalize q/k, split heads, and apply RoPE (no-op when disabled).
        q, k = map(
            lambda t:
            self.rope(rearrange(t, "b n (h c) -> b n h c", h=self.heads), grid_size),
            (self.q_norm(q), self.k_norm(k))
        )
        v = rearrange(v, "b n (h c) -> b n h c", h=self.heads)
        out = attn_processor(q, k, v, attn_mask=mask, **self.attn_ops) * scale
        out = rearrange(out, "b n h c -> b n (h c)")
        return out

    def masked_forward(self, x, context, mask, scale=1., scale_factor=None):
        # split cross-attention function
        def qkv_forward(x, context):
            q = self.to_q(x)
            k = self.to_k(context)
            v = self.to_v(context)
            return q, k, v

        assert exists(scale_factor), "Scale factor must be assigned before masked attention"
        # Resample the spatial mask to the token grid and flatten it to
        # (b, h*w, c) so it aligns with the query sequence.
        mask = rearrange(
            F.interpolate(mask, scale_factor=scale_factor, mode="bicubic"),
            "b c h w -> b (h w) c"
        ).contiguous()

        if self.merge_scale > 0:
            # split cross-attention with merging scale, need two times forward
            # (context is expected to hold fg/bg halves along the token axis).
            c1, c2 = context.chunk(2, dim=1)

            # Background region cross-attention
            q2, k2, v2 = qkv_forward(x, c2)
            bg_out = self.attn_forward(q2, k2, v2, scale) * self.bg_scale

            # Foreground region cross-attention
            q1, k1, v1 = qkv_forward(x, c1)
            fg_out = self.attn_forward(q1, k1, v1, scale) * self.fg_scale

            # Blend, then pick fg/bg per query position from the mask.
            fg_out = fg_out * (1 - self.merge_scale) + bg_out * self.merge_scale
            return torch.where(mask < self.mask_threshold, bg_out, fg_out)

        else:
            # Single pass: an additive bias restricts fg queries to the first
            # context half and bg queries to the second half.
            attn_mask = create_masked_attention_bias(
                mask, self.mask_threshold, self.heads, context.size(1)
            )
            q, k, v = qkv_forward(x, context)
            return self.attn_forward(q, k, v, mask=attn_mask) * scale
186
+
187
+
188
class MultiModalAttention(MemoryEfficientAttention):
    """Attention over two context modalities with separate key/value
    projections; the two attention results are summed with per-modality
    weights ``scale[0]``/``scale[1]``.

    The context is expected as a stacked tensor of shape (b, 2, n, c): the
    first slice feeds ``to_k``/``to_v``, the second ``to_k_2``/``to_v_2``.
    """

    def __init__(self, query_dim, context_dim_2, heads=8, dim_head=64, qk_norm=False, *args, **kwargs):
        super().__init__(query_dim, heads=heads, dim_head=dim_head, qk_norm=qk_norm, *args, **kwargs)
        inner_dim = dim_head * heads
        self.to_k_2 = nn.Linear(context_dim_2, inner_dim, bias=False)
        self.to_v_2 = nn.Linear(context_dim_2, inner_dim, bias=False)
        self.k2_norm = RMSNorm(inner_dim) if qk_norm else Identity()

    def forward(self, x, context=None, mask=None, scale=1., grid_size=None):
        # Broadcast a scalar scale to one weight per modality.
        if not isinstance(scale, (list, tuple)):
            scale = (scale, scale)
        assert len(context.shape) == 4, "Multi-modal attention requires different context inputs to be (b, m, n, c)"
        context, context2 = context.chunk(2, dim=1)
        # chunk() keeps the modality axis; drop it so the linear projections
        # and the "b n (h c)" rearrange below see 3-D (b, n, c) tensors.
        context, context2 = context.squeeze(1), context2.squeeze(1)

        q = self.to_q(x)
        k = self.to_k(context)
        v = self.to_v(context)
        k2 = self.to_k_2(context2)
        # FIX: values of the second modality were projected with `to_k_2`,
        # leaving `to_v_2` unused and feeding keys in place of values.
        v2 = self.to_v_2(context2)

        b, _, _ = q.shape
        # Normalize queries/keys, split heads, and apply RoPE (no-op if off).
        q, k, k2 = map(
            lambda t: self.rope(rearrange(t, "b n (h c) -> b n h c", h=self.heads), grid_size),
            (self.q_norm(q), self.k_norm(k), self.k2_norm(k2))
        )
        v, v2 = map(lambda t: rearrange(t, "b n (h c) -> b n h c", h=self.heads), (v, v2))

        # Weighted sum of the two modality attentions.
        out = (attn_processor(q, k, v, **self.attn_ops) * scale[0] +
               attn_processor(q, k2, v2, **self.attn_ops) * scale[1])

        if exists(mask):
            raise NotImplementedError
        out = rearrange(out, "b n h c -> b n (h c)")
        return self.to_out(out)
222
+
223
+
224
class MultiScaleCausalAttention(MemoryEfficientAttention):
    """Attention over a sequence made of several concatenated scale chunks.

    Each chunk gets RoPE at its own grid size, and queries of chunk *i*
    attend only to the prefix of keys/values up to and including chunk *i*
    (causality across scales, full attention within the visible prefix).
    The first chunk carries one extra leading token that is excluded from
    RoPE — presumably a global/conditioning token; confirm with callers.
    """

    def forward(
        self,
        x,
        context=None,
        mask=None,
        scale=1.,
        scale_factor=None,
        grid_size=None,
        token_lens=None
    ):
        # Self-attention when context is absent; mask is unused on this path.
        context = default(context, x)
        q = self.to_q(x)
        k = self.to_k(context)
        v = self.to_v(context)
        out = self.attn_forward(q, k, v, scale, grid_size=grid_size, token_lens=token_lens)
        return self.to_out(out)

    def attn_forward(self, q, k, v, scale = 1., grid_size = None, token_lens = None):
        # grid_size: per-chunk (h, w); token_lens: per-chunk token counts.
        q, k, v = map(
            lambda t: rearrange(t, "b n (h c) -> b n h c", h=self.heads),
            (self.q_norm(q), self.k_norm(k), v)
        )

        attn_output = []
        prev_idx = 0
        for idx, (grid, length) in enumerate(zip(grid_size, token_lens)):
            # (idx == 0) accounts for the extra leading token of chunk 0,
            # which is included in attention but skipped for RoPE.
            end_idx = prev_idx + length + (idx == 0)
            rope_prev_idx = prev_idx + (idx == 0)
            rope_slice = slice(rope_prev_idx, end_idx)

            # In-place RoPE on this chunk's span of q and k.
            q[:, rope_slice] = self.rope(q[:, rope_slice], grid)
            k[:, rope_slice] = self.rope(k[:, rope_slice], grid)
            qs = q[:, prev_idx: end_idx]
            # Keys/values are the causal prefix up to this chunk's end.
            ks, vs = map(lambda t: t[:, :end_idx], (k, v))

            # .clone() detaches the views from the in-place-mutated buffers.
            attn_output.append(attn_processor(qs.clone(), ks.clone(), vs.clone()) * scale)
            prev_idx = end_idx
        attn_output = rearrange(torch.cat(attn_output, 1), "b n h c -> b n (h c)")
        return attn_output
264
+
265
+ # if FLASH_ATTN_3_AVAILABLE or FLASH_ATTN_AVAILABLE:
266
+ # k_chunks = []
267
+ # v_chunks = []
268
+ # kv_token_lens = []
269
+ # prev_idx = 0
270
+ # for idx, (grid, length) in enumerate(zip(grid_size, token_lens)):
271
+ # end_idx = prev_idx + length + (idx == 0)
272
+ # rope_prev_idx = prev_idx + (idx == 0)
273
+
274
+ # rope_slice = slice(rope_prev_idx, end_idx)
275
+ # q[:, rope_slice], k[:, rope_slice], v[:, rope_slice] = map(
276
+ # lambda t: self.rope(t[:, rope_slice], grid),
277
+ # (q, k, v)
278
+ # )
279
+ # kv_token_lens.append(end_idx+1)
280
+ # k_chunks.append(k[:, :end_idx])
281
+ # v_chunks.append(v[:, :end_idx])
282
+ # prev_idx = end_idx
283
+ # k = torch.cat(k_chunks, 1)
284
+ # v = torch.cat(v_chunks, 1)
285
+ # B, N, H, C = q.shape
286
+ # token_lens = torch.tensor(token_lens, device=q.device, dtype=torch.int32)
287
+ # kv_token_lens = torch.tensor(kv_token_lens, device=q.device, dtype=torch.int32)
288
+ # token_lens[0] = token_lens[0] + 1
289
+ #
290
+ # cu_seqlens_q, cu_seqlens_kv = map(lambda t:
291
+ # torch.cat([t.new_zeros([1]), t]).cumsum(0, dtype=torch.int32),
292
+ # (token_lens, kv_token_lens)
293
+ # )
294
+ # max_seqlen_q, max_seqlen_kv = map(lambda t: int(t.max()), (token_lens, kv_token_lens))
295
+ #
296
+ # q_flat = q.reshape(-1, H, C).contiguous()
297
+ # k_flat = k.reshape(-1, H, C).contiguous()
298
+ # v_flat = v.reshape(-1, H, C).contiguous()
299
+ # out_flat = flash_attn_varlen_func(
300
+ # q=q_flat, k=k_flat, v=v_flat,
301
+ # cu_seqlens_q=cu_seqlens_q,
302
+ # cu_seqlens_k=cu_seqlens_kv,
303
+ # max_seqlen_q=max_seqlen_q,
304
+ # max_seqlen_k=max_seqlen_kv,
305
+ # causal=True,
306
+ # )
307
+ #
308
+ # out = rearrange(out_flat, "(b n) h c -> b n (h c)", b=B, n=N)
309
+ # return out * scale
refnet/modules/attn_utils.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ ATTN_PRECISION = torch.float16
5
+
6
+ try:
7
+ import flash_attn_interface
8
+ FLASH_ATTN_3_AVAILABLE = True
9
+ FLASH_ATTN_AVAILABLE = False
10
+
11
+ except ModuleNotFoundError:
12
+ FLASH_ATTN_3_AVAILABLE = False
13
+ try:
14
+ import flash_attn
15
+ FLASH_ATTN_AVAILABLE = True
16
+ except ModuleNotFoundError:
17
+ FLASH_ATTN_AVAILABLE = False
18
+
19
+ try:
20
+ import xformers.ops
21
+ XFORMERS_IS_AVAILBLE = True
22
+ except:
23
+ XFORMERS_IS_AVAILBLE = False
24
+
25
+
26
def half(x):
    """Cast *x* to the attention compute dtype (``ATTN_PRECISION``) unless it
    is already a 16-bit float (fp16 or bf16)."""
    if x.dtype in (torch.float16, torch.bfloat16):
        return x
    return x.to(ATTN_PRECISION)
30
+
31
def attn_processor(q, k, v, attn_mask = None, *args, **kwargs):
    """Dispatch scaled-dot-product attention to the best available backend.

    Preference when no mask is given: FlashAttention 3 > FlashAttention 2 >
    xFormers > PyTorch SDPA. With an additive ``attn_mask``, only xFormers
    (as ``attn_bias``) or PyTorch SDPA are used, since the flash kernels here
    take no bias. Inputs appear to be laid out (batch, seq, heads, dim),
    given the transposes around SDPA — confirm with callers.
    """
    if attn_mask is not None:
        if XFORMERS_IS_AVAILBLE:
            out = xformers.ops.memory_efficient_attention(
                q, k, v, attn_bias=attn_mask, *args, **kwargs
            )
        else:
            # SDPA expects (b, heads, seq, dim); transpose in and back out.
            q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))
            out = F.scaled_dot_product_attention(
                q, k, v, attn_mask=attn_mask, *args, **kwargs
            ).transpose(1, 2)
    else:
        if FLASH_ATTN_3_AVAILABLE:
            # Flash kernels require fp16/bf16; cast down, restore dtype after.
            dtype = v.dtype
            q, k, v = map(lambda t: half(t), (q, k, v))
            out = flash_attn_interface.flash_attn_func(q, k, v, *args, **kwargs)[0].to(dtype)
        elif FLASH_ATTN_AVAILABLE:
            dtype = v.dtype
            q, k, v = map(lambda t: half(t), (q, k, v))
            out = flash_attn.flash_attn_func(q, k, v, *args, **kwargs).to(dtype)
        elif XFORMERS_IS_AVAILBLE:
            out = xformers.ops.memory_efficient_attention(q, k, v, *args, **kwargs)
        else:
            q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))
            out = F.scaled_dot_product_attention(q, k, v, *args, **kwargs).transpose(1, 2)
    return out
57
+
58
+
59
def flash_attn_varlen_func(q, k, v, **kwargs):
    """Variable-length flash attention, preferring FlashAttention 3.

    FA3 returns a tuple (output first); FA2 returns the output directly —
    this wrapper normalizes both to the bare output tensor. If neither
    package imported (see module flags above), this raises NameError;
    callers are expected to gate on availability.
    """
    if FLASH_ATTN_3_AVAILABLE:
        return flash_attn_interface.flash_attn_varlen_func(q, k, v, **kwargs)[0]
    else:
        return flash_attn.flash_attn_varlen_func(q, k, v, **kwargs)
64
+
65
+
66
def split_tensor_by_mask(tensor: torch.Tensor, mask: torch.Tensor, threshold: float = 0.5):
    """
    Reorder each batch row so foreground elements (mask > threshold) come
    first, followed by background elements, each group zero-padded to the
    batch-wide maximum group length.

    Args:
        tensor: Input of shape (batch, seq_len, *dims)
        mask: Mask of shape (batch, seq_len) or (batch, seq_len, 1)
        threshold: Binarization threshold for the mask

    Returns:
        split_tensor: (batch, max_fg + max_bg, *dims) — fg block then bg block
        fg_indices: per-batch foreground index tensors (for restoration)
        bg_indices: per-batch background index tensors (for restoration)
        original_shape: ``tensor.shape`` (for restoration)
    """
    batch_size, seq_len, *feat_dims = tensor.shape
    device, dtype = tensor.device, tensor.dtype

    # Normalize mask to (batch, seq_len) boolean.
    if mask.dim() == 2:
        mask = mask.unsqueeze(-1)
    binary_mask = (mask > threshold).squeeze(-1)

    # Per-batch index lists, kept for the inverse operation.
    fg_indices = [binary_mask[b].nonzero(as_tuple=True)[0] for b in range(batch_size)]
    bg_indices = [(~binary_mask[b]).nonzero(as_tuple=True)[0] for b in range(batch_size)]

    max_fg_len = int(binary_mask.sum(dim=1).max())
    max_bg_len = int((~binary_mask).sum(dim=1).max())

    # Degenerate case: nothing to place at all.
    if max_fg_len == 0 and max_bg_len == 0:
        empty = torch.zeros(batch_size, 0, *feat_dims, device=device, dtype=dtype)
        return empty, fg_indices, bg_indices, tensor.shape

    out = torch.zeros(batch_size, max_fg_len + max_bg_len, *feat_dims, device=device, dtype=dtype)
    for b in range(batch_size):
        fg, bg = fg_indices[b], bg_indices[b]
        if fg.numel():
            out[b, :fg.numel()] = tensor[b][fg]
        if bg.numel():
            out[b, max_fg_len:max_fg_len + bg.numel()] = tensor[b][bg]

    return out, fg_indices, bg_indices, tensor.shape
114
+
115
+
116
def restore_tensor_from_split(split_tensor: torch.Tensor, fg_indices: list, bg_indices: list,
                              original_shape: torch.Size):
    """
    Invert :func:`split_tensor_by_mask`: scatter the foreground block and the
    background block of ``split_tensor`` back to their original positions.

    Args:
        split_tensor: Output of ``split_tensor_by_mask``
        fg_indices: Per-batch foreground indices
        bg_indices: Per-batch background indices
        original_shape: Shape of the pre-split tensor

    Returns:
        Tensor of ``original_shape`` with the original element ordering
        (positions absent from both index lists stay zero).
    """
    batch_size, seq_len = original_shape[:2]
    feat_dims = original_shape[2:]

    restored = torch.zeros(
        batch_size, seq_len, *feat_dims,
        device=split_tensor.device, dtype=split_tensor.dtype
    )

    # Nothing was split out — return the all-zero canvas.
    if split_tensor.shape[1] == 0:
        return restored

    # Boundary between the foreground and background blocks.
    max_fg_len = max((len(fg) for fg in fg_indices), default=0)
    has_bg_block = split_tensor.shape[1] > max_fg_len

    for b in range(batch_size):
        fg, bg = fg_indices[b], bg_indices[b]
        if max_fg_len > 0 and len(fg) > 0:
            restored[b, fg] = split_tensor[b, :len(fg)]
        if has_bg_block and len(bg) > 0:
            restored[b, bg] = split_tensor[b, max_fg_len:max_fg_len + len(bg)]

    return restored
refnet/modules/embedder.py ADDED
@@ -0,0 +1,489 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import numpy as np
4
+
5
+ from tqdm import tqdm
6
+ from einops import rearrange
7
+ from refnet.util import exists, append_dims
8
+ from refnet.sampling import tps_warp
9
+ from refnet.ldm.openaimodel import Timestep, zero_module
10
+
11
+ import timm
12
+ import torch
13
+ import torch.nn as nn
14
+ import torchvision.transforms
15
+ import torch.nn.functional as F
16
+
17
+ from huggingface_hub import hf_hub_download
18
+ from torch.utils.checkpoint import checkpoint
19
+ from safetensors.torch import load_file
20
+ from transformers import (
21
+ T5EncoderModel,
22
+ T5Tokenizer,
23
+ CLIPVisionModelWithProjection,
24
+ CLIPTextModel,
25
+ CLIPTokenizer,
26
+ )
27
+
28
+ versions = {
29
+ "ViT-bigG-14": "laion2b_s39b_b160k",
30
+ "ViT-H-14": "laion2b_s32b_b79k", # resblocks layers: 32
31
+ "ViT-L-14": "laion2b_s32b_b82k",
32
+ "hf-hub:apple/DFN5B-CLIP-ViT-H-14-384": None, # arch name [DFN-ViT-H]
33
+ }
34
+ hf_versions = {
35
+ "ViT-bigG-14": "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
36
+ "ViT-H-14": "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
37
+ "ViT-L-14": "openai/clip-vit-large-patch14",
38
+ }
39
+ cache_dir = os.environ.get("HF_HOME", "./pretrained_models")
40
+
41
+
42
class WDv14SwinTransformerV2(nn.Module):
    """
    WD-v14-tagger
    Author: Smiling Wolf
    Link: https://huggingface.co/SmilingWolf/wd-v1-4-swinv2-tagger-v2

    Frozen SwinV2 tagger used both as a feature extractor (intermediate
    transformer features) and, when ``load_tag=True``, as an anime tag
    classifier over 9083 classes.
    """
    # Logit value used to suppress channels below `logit_threshold`.
    negative_logit = -22

    def __init__(
        self,
        input_size = 448,
        antialias = True,
        layer_idx = 0.,
        load_tag = False,
        logit_threshold = None,
        direct_forward = False,
    ):
        """

        Args:
            input_size: Input image size
            antialias: Antialias during rescaling
            layer_idx: Extracted feature layer (counted back from the last block)
            load_tag: Set it to true if use the embedder for image classification
            logit_threshold: Filtering specific channels in logits output
            direct_forward: Bypass `transformer_forward` and use timm's
                `forward_features` directly
        """
        from refnet.modules import wd_v14_swin2_tagger_config
        super().__init__()
        custom_config = wd_v14_swin2_tagger_config()
        self.model: nn.Module = timm.create_model(
            custom_config.architecture,
            pretrained = False,
            num_classes = custom_config.num_classes,
            global_pool = custom_config.global_pool,
            **custom_config.model_args
        )
        self.image_size = input_size
        self.antialias = antialias
        self.layer_idx = layer_idx
        self.load_tag = load_tag
        self.logit_threshold = logit_threshold
        self.direct_forward = direct_forward

        self.load_from_pretrained_url(load_tag)
        self.get_transformer_length()
        self.model.eval()
        self.model.requires_grad_(False)

        if self.direct_forward:
            # Rebind forward to timm's forward_features (bound method trick).
            self.model.forward = self.model.forward_features.__get__(self.model, self.model.__class__)


    def load_from_pretrained_url(self, load_tag=False):
        """Download (once, cached) and load the tagger weights; optionally
        also load the tag-name tables needed by `convert_labels`."""
        import pandas as pd
        from torch.hub import download_url_to_file
        from data.tag_utils import load_labels, color_tag_index, geometry_tag_index

        ckpt_path = os.path.join(cache_dir, "wd-v14-swin2-tagger.safetensors")
        if not os.path.exists(ckpt_path):
            # Download to a temp name, then rename — avoids half-written files.
            cache_path = os.path.join(cache_dir, "weights.tmp")
            download_url_to_file(
                "https://huggingface.co/SmilingWolf/wd-v1-4-swinv2-tagger-v2/resolve/main/model.safetensors",
                dst = cache_path
            )
            os.rename(cache_path, ckpt_path)

        if load_tag:
            csv_path = hf_hub_download(
                "SmilingWolf/wd-v1-4-swinv2-tagger-v2",
                "selected_tags.csv",
                cache_dir = cache_dir
                # use_auth_token=HF_TOKEN,
            )
            tags_df = pd.read_csv(csv_path)
            sep_tags = load_labels(tags_df)

            self.tag_names = sep_tags[0]
            self.rating_indexes = sep_tags[1]
            self.general_indexes = sep_tags[2]
            self.character_indexes = sep_tags[3]

            self.color_tags = color_tag_index
            self.expr_tags = geometry_tag_index
        self.model.load_state_dict(load_file(ckpt_path))


    def convert_labels(self, pred, general_thresh=0.25, character_thresh=0.85):
        """Turn sigmoid predictions (shape (1, num_classes)) into a prompt
        string of general+character tags plus the sorted general-tag list.
        Requires ``load_tag=True`` at construction."""
        assert self.load_tag
        labels = list(zip(self.tag_names, pred[0].astype(float)))

        # First 4 labels are actually ratings: pick one with argmax
        # ratings_names = [labels[i] for i in self.rating_indexes]
        # rating = dict(ratings_names)

        # Then we have general tags: pick any where prediction confidence > threshold
        general_names = [labels[i] for i in self.general_indexes]

        general_res = [(x[0], np.round(x[1], decimals=4)) for x in general_names if x[1] > general_thresh]
        general_res = dict(general_res)

        # Everything else is characters: pick any where prediction confidence > threshold
        character_names = [labels[i] for i in self.character_indexes]

        character_res = [x for x in character_names if x[1] > character_thresh]
        character_res = dict(character_res)

        sorted_general_strings = sorted(
            general_res.items(),
            key=lambda x: x[1],
            reverse=True,
        )

        sorted_general_res = sorted(
            general_res.items(),
            key=lambda x: x[1],
            reverse=True,
        )
        sorted_general_strings = [x[0] for x in sorted_general_strings]
        # Escape parentheses for downstream prompt syntax.
        sorted_general_strings = ", ".join(sorted_general_strings).replace("(", "\\(").replace(")", "\\)")

        # return sorted_general_strings, rating, character_res, general_res
        # NOTE(review): no ", " separator is inserted between the general
        # string and the character tags here — confirm this is intended.
        return sorted_general_strings + ", ".join([x[0] for x in character_res.items()]), sorted_general_res

    def get_transformer_length(self):
        # Total number of transformer blocks across all Swin stages.
        length = 0
        for stage in self.model.layers:
            length += len(stage.blocks)
        self.transformer_length = length

    def transformer_forward(self, x):
        """Run patch-embed + blocks, returning early at the block
        ``transformer_length - layer_idx`` (counted from the start)."""
        idx = 0
        x = self.model.patch_embed(x)
        for stage in self.model.layers:
            x = stage.downsample(x)
            for blk in stage.blocks:
                if idx == self.transformer_length - self.layer_idx:
                    return x
                if not torch.jit.is_scripting():
                    # Gradient checkpointing to save activation memory.
                    x = checkpoint(blk, x, use_reentrant=False)
                else:
                    x = blk(x)
                idx += 1
        return x


    def forward(self, x, return_logits=False, pooled=True, **kwargs):
        # x: [b, h, w, 3]
        if self.direct_forward:
            x = self.model(x)
        else:
            x = self.transformer_forward(x)
            x = self.model.norm(x)

        # x: [b, 14, 14, 1024]
        if return_logits:
            if pooled:
                logits = self.model.forward_head(x).unsqueeze(1)
                # x: [b, 1, 1024]

            else:
                # Per-token logits (no pooling head).
                logits = self.model.head.fc(x)
                # x = F.sigmoid(x)
                logits = rearrange(logits, "b h w c -> b (h w) c").contiguous()
                # x: [b, 196, 9083]

            # Need a threshold to cut off unnecessary classes.
            if exists(self.logit_threshold) and isinstance(self.logit_threshold, float):
                logits = torch.where(
                    logits > self.logit_threshold,
                    logits,
                    torch.ones_like(logits) * self.negative_logit
                )

        else:
            logits = None

        if pooled:
            # Global average over the spatial grid -> single token.
            x = x.mean(dim=[1, 2]).unsqueeze(1)
        else:
            x = rearrange(x, "b h w c -> b (h w) c").contiguous()
        return [x, logits]

    def preprocess(self, x: torch.Tensor):
        # Resize to the tagger's native resolution.
        x = F.interpolate(
            x,
            (self.image_size, self.image_size),
            mode = "bicubic",
            align_corners = True,
            antialias = self.antialias
        )
        # convert RGB to BGR
        x = x[:, [2, 1, 0]]
        return x

    @torch.no_grad()
    def encode(self, img: torch.Tensor, return_logits=False, pooled=True, **kwargs):
        # Input image must be in RGB format
        return self(self.preprocess(img), return_logits, pooled)

    @torch.no_grad()
    def predict_labels(self, img: torch.Tensor, *args, **kwargs):
        """Classify a single RGB image (batch of 1) and return tag strings."""
        assert len(img.shape) == 4 and img.shape[0] == 1
        logits = self(self.preprocess(img), return_logits=True, pooled=True)[1]
        logits = F.sigmoid(logits).detach().cpu().numpy()
        return self.convert_labels(logits, *args, **kwargs)

    def geometry_update(self, emb, geometry_emb, scale_factor=1):
        """

        Args:
            emb: WD embedding from reference image
            geometry_emb: WD embedding from sketch image

        Replaces only the geometry-tag channels (``self.expr_tags``) of `emb`
        with the (scaled) sketch embedding; other channels are untouched.
        """
        geometry_mask = torch.zeros_like(emb)
        geometry_mask[:, :, self.expr_tags] = 1  # Only geometry channels
        emb = emb * (1 - geometry_mask) + geometry_emb * geometry_mask * scale_factor
        return emb

    @property
    def dtype(self):
        # Dtype of the classifier head's weight — proxy for the model dtype.
        return self.model.head.fc.weight.dtype
263
+ return self.model.head.fc.weight.dtype
264
+
265
+
266
class OpenCLIP(nn.Module):
    """Bundles an OpenCLIP vision tower and text transformer in one module,
    delegating most calls to the vision branch.

    NOTE(review): ``FrozenOpenCLIPImageEmbedder`` / ``FrozenOpenCLIPEmbedder``
    are not defined in this module as far as visible — confirm they are in
    scope where this class is instantiated.
    """

    def __init__(self, vision_config=None, text_config=None, **kwargs):
        super().__init__()
        # Shared kwargs are merged into whichever per-branch config exists,
        # or used as the whole config when none was supplied.
        if exists(vision_config):
            vision_config.update(kwargs)
        else:
            vision_config = kwargs

        if exists(text_config):
            text_config.update(kwargs)
        else:
            text_config = kwargs

        self.visual = FrozenOpenCLIPImageEmbedder(**vision_config)
        self.transformer = FrozenOpenCLIPEmbedder(**text_config)

    def preprocess(self, x):
        # Delegate image preprocessing to the vision branch.
        return self.visual.preprocess(x)

    @property
    def scale_factor(self):
        return self.visual.scale_factor

    def update_scale_factor(self, scale_factor):
        self.visual.update_scale_factor(scale_factor)

    def encode(self, *args, **kwargs):
        # Image encoding is handled entirely by the vision branch.
        return self.visual.encode(*args, **kwargs)

    @torch.no_grad()
    def encode_text(self, text, normalize=True):
        return self.transformer(text, normalize)

    def calculate_scale(self, v: torch.Tensor, t: torch.Tensor):
        """
        Calculate the projection of v along the direction of t
        params:
            v: visual tokens from clip image encoder, shape: (b, n, c)
            t: text features from clip text encoder (argmax -1), shape: (b, 1, c)
        """
        return v @ t.mT
307
+
308
+
309
+
310
class HFCLIPVisionModel(nn.Module):
    # TODO: open_clip_torch is incompatible with deepspeed ZeRO3, change to huggingface implementation in the future
    """HuggingFace CLIP vision tower with projection head, frozen at load."""

    def __init__(self, arch="ViT-bigG-14", image_size=224, scale_factor=1.):
        super().__init__()
        self.model = CLIPVisionModelWithProjection.from_pretrained(
            hf_versions[arch],
            cache_dir = cache_dir
        )
        self.image_size = image_size
        self.scale_factor = scale_factor  # multiplies image_size at preprocess time
        # CLIP normalization statistics, shaped for NCHW broadcasting.
        self.register_buffer(
            'mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]).view(1, -1, 1, 1), persistent=False
        )
        self.register_buffer(
            'std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]).view(1, -1, 1, 1), persistent=False
        )
        self.antialias = True
        self.requires_grad_(False).eval()

    def preprocess(self, x):
        # normalize to [0,1]  (assumes input is in [-1, 1] — TODO confirm)
        ns = int(self.image_size * self.scale_factor)
        x = F.interpolate(x, (ns, ns), mode="bicubic", align_corners=True, antialias=self.antialias)
        x = (x + 1.0) / 2.0

        # renormalize according to clip
        x = (x - self.mean) / self.std
        return x

    def forward(self, x, output_type):
        # output_type: "cls" keeps only the class token, "local" keeps patch
        # tokens; any other value keeps the full token sequence.
        outputs = self.model(x).last_hidden_state
        if output_type == "cls":
            outputs = outputs[:, :1]
        elif output_type == "local":
            outputs = outputs[:, 1:]
        outputs = self.model.vision_model.post_layernorm(outputs)
        outputs = self.model.visual_projection(outputs)
        return outputs

    @torch.no_grad()
    def encode(self, img, output_type="full", preprocess=True, warp_p=0., **kwargs):
        img = self.preprocess(img) if preprocess else img

        if warp_p > 0.:
            # Randomly TPS-warp a fraction of the batch (augmentation).
            # NOTE(review): torch.Tensor(rand > warp_p) re-wraps a boolean
            # tensor as float, while torch.where expects a bool condition —
            # verify this path actually runs as intended.
            rand = append_dims(torch.rand(img.shape[0], device=img.device, dtype=img.dtype), img.ndim)
            img = torch.where(torch.Tensor(rand > warp_p), img, tps_warp(img))
        return self(img, output_type)
357
+
358
+
359
+
360
+
361
class FrozenT5Embedder(nn.Module):
    """Uses the T5 transformer encoder for text"""

    def __init__(
        self, version="google/t5-v1_1-xxl", device="cuda", max_length=77, freeze=True
    ):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
        super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(version, cache_dir=cache_dir)
        self.transformer = T5EncoderModel.from_pretrained(version, cache_dir=cache_dir)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()

    def freeze(self):
        # Eval mode and no gradients anywhere in the module.
        self.transformer = self.transformer.eval()
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, text):
        # Tokenize with fixed-length padding/truncation, then run the encoder
        # with autocast disabled (full precision) for numerical stability.
        encoded = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        input_ids = encoded["input_ids"].to(self.device)
        with torch.autocast("cuda", enabled=False):
            hidden = self.transformer(input_ids=input_ids).last_hidden_state
        return hidden

    @torch.no_grad()
    def encode(self, text):
        return self(text)
400
+
401
+
402
class HFCLIPTextEmbedder(nn.Module):
    """CLIP text encoder (HuggingFace) returning last hidden states; weights
    are frozen when ``freeze=True``."""

    def __init__(self, arch, freeze=True, device="cuda", max_length=77):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(
            hf_versions[arch],
            cache_dir = cache_dir
        )
        self.model = CLIPTextModel.from_pretrained(
            hf_versions[arch],
            cache_dir = cache_dir
        )
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()

    def freeze(self):
        # Eval mode and no gradients anywhere in the module.
        self.model = self.model.eval()
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, text):
        # Accept either pre-tokenized LongTensor ids or raw string input.
        if isinstance(text, torch.Tensor) and text.dtype == torch.long:
            tokens = text
        else:
            tokens = self.tokenizer(
                text,
                truncation=True,
                max_length=self.max_length,
                padding="max_length",
                return_tensors="pt",
            )["input_ids"].to(self.device)
        return self.model(input_ids=tokens).last_hidden_state

    @torch.no_grad()
    def encode(self, text, normalize=False):
        feats = self(text)
        if normalize:
            # L2-normalize along the feature axis.
            feats = feats / feats.norm(dim=-1, keepdim=True)
        return feats
449
+
450
+
451
class ScalarEmbedder(nn.Module):
    """Embeds a scalar: sinusoidal timestep features followed by a two-layer
    MLP whose final projection is zero-initialized; emits one token per
    batch element."""

    def __init__(self, embed_dim, out_dim):
        super().__init__()
        self.timestep = Timestep(embed_dim)
        self.embed_layer = nn.Sequential(
            nn.Linear(embed_dim, out_dim),
            nn.SiLU(),
            zero_module(nn.Linear(out_dim, out_features=out_dim))
        )

    def forward(self, x, dtype=torch.float32):
        sinusoid = rearrange(self.timestep(x), "b d -> b 1 d")
        return self.embed_layer(sinusoid.to(dtype))
468
+
469
+
470
class TimestepEmbedding(nn.Module):
    """Thin wrapper that returns raw sinusoidal timestep features."""

    def __init__(self, embed_dim):
        super().__init__()
        self.timestep = Timestep(embed_dim)

    def forward(self, x):
        return self.timestep(x)
478
+
479
+
480
if __name__ == '__main__':
    import PIL.Image as Image

    # Smoke test: encode a local sample image and print the feature shape.
    # NOTE(review): FrozenOpenCLIPImageEmbedder is not defined or imported in
    # this module as far as visible — running this block would raise
    # NameError; confirm where it is meant to come from.
    encoder = FrozenOpenCLIPImageEmbedder(arch="DFN-ViT-H")
    image = Image.open("../../miniset/origin/70717450.jpg").convert("RGB")
    # Map [0, 1] pixels to [-1, 1] as the encoder's preprocess expects.
    image = (torchvision.transforms.ToTensor()(image) - 0.5) * 2
    image = image.unsqueeze(0)
    print(image.shape)
    feat = encoder.encode(image, "local")
    print(feat.shape)
refnet/modules/encoder.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from refnet.util import checkpoint_wrapper
6
+ from refnet.modules.unet import TimestepEmbedSequential
7
+ from refnet.modules.layers import Upsample, zero_module, RMSNorm, FeedForward
8
+ from refnet.modules.attention import MemoryEfficientAttention, MultiScaleCausalAttention
9
+ from einops import rearrange
10
+ from functools import partial
11
+
12
+
13
+
14
def make_zero_conv(in_channels, out_channels=None):
    """Zero-initialized 1x1 convolution (out_channels defaults to in_channels)."""
    return zero_module(nn.Conv2d(in_channels, out_channels or in_channels, 1, padding=0))
17
+
18
def activate_zero_conv(in_channels, out_channels=None):
    """SiLU followed by a zero-initialized 1x1 conv, in a timestep-aware container."""
    out_ch = out_channels or in_channels
    return TimestepEmbedSequential(
        nn.SiLU(),
        zero_module(nn.Conv2d(in_channels, out_ch, 1, padding=0)),
    )
24
+
25
def sequential_downsample(in_channels, out_channels, sequential_cls=nn.Sequential):
    # 8x spatial downsampling stem: three stride-2 convs with SiLU activations,
    # widening channels 16 -> 32 -> 96 -> 256 and ending in a zero-initialized
    # projection so the encoder initially contributes nothing to the host network.
    return sequential_cls(
        nn.Conv2d(in_channels, 16, 3, padding=1),
        nn.SiLU(),
        nn.Conv2d(16, 16, 3, padding=1),
        nn.SiLU(),
        nn.Conv2d(16, 32, 3, padding=1, stride=2),
        nn.SiLU(),
        nn.Conv2d(32, 32, 3, padding=1),
        nn.SiLU(),
        nn.Conv2d(32, 96, 3, padding=1, stride=2),
        nn.SiLU(),
        nn.Conv2d(96, 96, 3, padding=1),
        nn.SiLU(),
        nn.Conv2d(96, 256, 3, padding=1, stride=2),
        nn.SiLU(),
        zero_module(nn.Conv2d(256, out_channels, 3, padding=1))
    )
43
+
44
+
45
class SimpleEncoder(nn.Module):
    """Plain convolutional downsampling encoder; extra forward args are ignored."""

    def __init__(self, c_channels, model_channels):
        super().__init__()
        self.model = sequential_downsample(c_channels, model_channels)

    def forward(self, x, *args, **kwargs):
        return self.model(x)
52
+
53
+
54
class MultiEncoder(nn.Module):
    """Downsampling hint encoder emitting zero-initialized feature maps per scale.

    The stem reduces resolution 8x; each output block halves resolution again
    (except the last), and every scale passes its own zero conv so all hints
    start as zeros.
    """

    def __init__(self, in_ch, model_channels, ch_mults, checkpoint=True, time_embed=False):
        super().__init__()
        # TimestepEmbedSequential lets blocks optionally consume the timestep embedding.
        sequential_cls = TimestepEmbedSequential if time_embed else nn.Sequential
        output_chs = [model_channels * mult for mult in ch_mults]
        # NOTE(review): the stem outputs model_channels while zero_layer expects
        # output_chs[0] = model_channels * ch_mults[0] — assumes ch_mults[0] == 1; confirm.
        self.model = sequential_downsample(in_ch, model_channels, sequential_cls)
        self.zero_layer = make_zero_conv(output_chs[0])
        self.output_blocks = nn.ModuleList()
        self.zero_blocks = nn.ModuleList()

        block_num = len(ch_mults)
        prev_ch = output_chs[0]
        for i in range(block_num):
            self.output_blocks.append(sequential_cls(
                nn.SiLU(),
                # Last block keeps resolution (stride 1); earlier blocks downsample.
                nn.Conv2d(prev_ch, output_chs[i], 3, padding=1, stride=2 if i != block_num-1 else 1),
                nn.SiLU(),
                nn.Conv2d(output_chs[i], output_chs[i], 3, padding=1)
            ))
            self.zero_blocks.append(
                TimestepEmbedSequential(make_zero_conv(output_chs[i])) if time_embed
                else make_zero_conv(output_chs[i])
            )
            prev_ch = output_chs[i]

        self.checkpoint = checkpoint

    def forward(self, x):
        # Returns one hint per scale (finest first); each goes through a zero conv.
        x = self.model(x)
        hints = [self.zero_layer(x)]
        for layer, zero_layer in zip(self.output_blocks, self.zero_blocks):
            x = layer(x)
            hints.append(zero_layer(x))
        return hints
88
+
89
+
90
class MultiScaleAttentionEncoder(nn.Module):
    """Hint encoder that projects every scale to a common width, runs one joint
    multi-scale causal attention pass over all tokens (one global pooled token
    first, then coarsest-to-finest scales), and projects each scale back with
    zero-initialized 1x1 convs so the hints start as zeros."""

    def __init__(
        self,
        in_ch,
        model_channels,
        ch_mults,
        dim_head = 128,
        transformer_layers = 2,
        checkpoint = True
    ):
        super().__init__()
        conv_proj = partial(nn.Conv2d, kernel_size=1, padding=0)
        output_chs = [model_channels * mult for mult in ch_mults]
        block_num = len(ch_mults)
        attn_ch = output_chs[-1]  # shared attention width = widest scale

        self.model = sequential_downsample(in_ch, output_chs[0])
        self.proj_ins = nn.ModuleList([conv_proj(output_chs[0], attn_ch)])
        self.proj_outs = nn.ModuleList([zero_module(conv_proj(attn_ch, output_chs[0]))])

        prev_ch = output_chs[0]
        self.downsample_layers = nn.ModuleList()
        for i in range(block_num):
            ch = output_chs[i]
            self.downsample_layers.append(nn.Sequential(
                nn.SiLU(),
                # Last stage keeps resolution; earlier stages halve it.
                nn.Conv2d(prev_ch, ch, 3, padding=1, stride=2 if i != block_num - 1 else 1),
            ))
            self.proj_ins.append(conv_proj(ch, attn_ch))
            self.proj_outs.append(zero_module(conv_proj(attn_ch, ch)))
            prev_ch = ch

        # Extra projection for the single global (spatially pooled) token.
        self.proj_ins.append(conv_proj(attn_ch, attn_ch))
        self.attn_layer = MultiScaleCausalAttention(attn_ch, rope=True, qk_norm=True, dim_head=dim_head)
        # NOTE(review): `transformer_layers` is currently unused; the stacked
        # transformer variant below is kept for reference.
        # self.transformer = nn.ModuleList([
        #     BasicTransformerBlock(
        #         attn_ch,
        #         rotary_positional_embedding = True,
        #         qk_norm = True,
        #         d_head = dim_head,
        #         disable_cross_attn = True,
        #         self_attn_type = "multi-scale",
        #         ff_mult = 2,
        #     )
        # ] * transformer_layers)
        self.checkpoint = checkpoint

    @checkpoint_wrapper
    def forward(self, x):
        proj_in_iter = iter(self.proj_ins)
        # proj_outs are consumed in reverse because the hint order is reversed below.
        proj_out_iter = iter(self.proj_outs[::-1])

        x = self.model(x)
        hints = [rearrange(next(proj_in_iter)(x), "b c h w -> b (h w) c")]
        grid_sizes = [(x.shape[2], x.shape[3])]
        token_lens = [(x.shape[2] * x.shape[3])]

        for layer in self.downsample_layers:
            x = layer(x)
            h, w = x.shape[2], x.shape[3]
            grid_sizes.append((h, w))
            token_lens.append(h * w)
            hints.append(rearrange(next(proj_in_iter)(x), "b c h w -> b (h w) c"))

        # Append one globally pooled token summarizing the coarsest features.
        hints.append(rearrange(
            next(proj_in_iter)(x.mean(dim=[2, 3], keepdim=True)),
            "b c h w -> b (h w) c"
        ))

        # Reversing puts the global token first, then scales coarsest-to-finest.
        hints = hints[::-1]
        grid_sizes = grid_sizes[::-1]
        token_lens = token_lens[::-1]
        hints = torch.cat(hints, 1)
        hints = self.attn_layer(hints, grid_size=grid_sizes, token_lens=token_lens) + hints
        # for layer in self.transformer:
        #     hints = layer(hints, grid_size=grid_sizes, token_lens=token_lens)

        # prev_idx starts at 1 to skip the global token; split back per scale.
        prev_idx = 1
        controls = []
        for gs, token_len in zip(grid_sizes, token_lens):
            control = hints[:, prev_idx: prev_idx + token_len]
            control = rearrange(control, "b (h w) c -> b c h w", h=gs[0], w=gs[1])
            controls.append(next(proj_out_iter)(control))
            prev_idx = prev_idx + token_len
        return controls[::-1]
175
+
176
+
177
+
178
class Downsampler(nn.Module):
    """Bicubic rescaling by a fixed factor; holds no learnable parameters."""

    def __init__(self, scale_factor):
        super().__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        return F.interpolate(input=x, scale_factor=self.scale_factor, mode="bicubic")
185
+
186
+
187
class SpatialConditionEncoder(nn.Module):
    """Patchify a spatial condition, refine it with a small pre-norm transformer,
    and emit a zero-initialized projection so it starts as a no-op control signal."""

    def __init__(
        self,
        in_dim,
        dim,
        out_dim,
        patch_size,
        n_layers = 4,
    ):
        super().__init__()
        self.patch_embed = nn.Conv2d(in_dim, dim, kernel_size=patch_size, stride=patch_size)
        self.conv = nn.Sequential(nn.SiLU(), nn.Conv2d(dim, dim, kernel_size=3, padding=1))

        # Pre-norm blocks: (RMSNorm -> attention) and (RMSNorm -> feed-forward);
        # residual connections are applied in forward.
        self.transformer = nn.ModuleList(
            nn.ModuleList([
                RMSNorm(dim),
                MemoryEfficientAttention(dim, rope=True),
                RMSNorm(dim),
                FeedForward(dim, mult=2)
            ]) for _ in range(n_layers)
        )
        self.out = nn.Sequential(
            nn.SiLU(),
            zero_module(nn.Conv2d(dim, out_dim, kernel_size=1, padding=0))
        )

    def forward(self, x):
        x = self.patch_embed(x)
        x = self.conv(x)

        b, c, h, w = x.shape
        x = rearrange(x, "b c h w -> b (h w) c")
        for norm, layer, norm2, ff in self.transformer:
            # grid_size lets rotary position embeddings recover 2D positions.
            x = layer(norm(x), grid_size=(h, w)) + x
            x = ff(norm2(x)) + x
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)

        return self.out(x)
refnet/modules/layers.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from refnet.util import default
8
+
9
+
10
+
11
class RMSNorm(nn.Module):
    """Root-mean-square layer norm (no mean subtraction), computed in float32."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # Divide by RMS over the last dim, stabilized by eps.
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x):
        normed = self._norm(x.float()).type_as(x)
        return normed * self.weight
23
+
24
+
25
+
26
def init_(tensor):
    """In-place uniform init in [-1/sqrt(d), 1/sqrt(d)], d = last dimension size."""
    bound = 1 / math.sqrt(tensor.shape[-1])
    tensor.uniform_(-bound, bound)
    return tensor
31
+
32
+
33
+ # feedforward
34
class GEGLU(nn.Module):
    """Gated GELU: one linear yields value and gate halves; output = value * gelu(gate)."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        value, gate = self.proj(x).chunk(2, dim=-1)
        return value * F.gelu(gate)
42
+
43
+
44
class FeedForward(nn.Module):
    """Transformer MLP: project up by `mult` (optionally gated), dropout, project back."""

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        if glu:
            project_in = GEGLU(dim, inner_dim)
        else:
            project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out),
        )

    def forward(self, x):
        return self.net(x)
62
+
63
+
64
def zero_module(module):
    """Zero all parameters of `module` in place and return it (zero-init output layers)."""
    for param in module.parameters():
        param.detach().zero_()
    return module
71
+
72
+
73
def Normalize(in_channels):
    """Standard 32-group GroupNorm used throughout the diffusion blocks."""
    return torch.nn.GroupNorm(
        num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
    )
75
+
76
+
77
class Upsample(nn.Module):
    """
    2x nearest-neighbor upsampling with an optional 3x3 convolution.

    Fix: the previous docstring documented a `dims` parameter (1D/2D/3D support)
    that this 2D-only constructor does not accept.

    :param channels: channels in the inputs (and outputs, unless out_channels is set).
    :param use_conv: whether to apply a 3x3 convolution after upsampling.
    :param out_channels: output channels of the convolution; defaults to `channels`.
    :param padding: padding of the optional convolution.
    """

    def __init__(self, channels, use_conv, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        if use_conv:
            self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=padding)

    def forward(self, x):
        assert x.shape[1] == self.channels
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x
refnet/modules/lora.py ADDED
@@ -0,0 +1,370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from typing import Union, Dict, List
6
+ from einops import rearrange
7
+ from refnet.util import exists, default
8
+ from refnet.modules.transformer import BasicTransformerBlock, SelfInjectedTransformerBlock
9
+
10
+
11
def get_module_safe(self, module_path: str):
    """Resolve a dotted attribute path (e.g. "model.diffusion_model") on `self`.

    Fixes: error message said "modules" for a single path, and the original
    AttributeError was discarded — now chained with `from err` so the failing
    segment is visible.

    :raises AttributeError: when any segment of the path does not exist.
    """
    current = self
    try:
        for part in module_path.split('.'):
            current = getattr(current, part)
    except AttributeError as err:
        raise AttributeError(f"Cannot find module {module_path}") from err
    return current
19
+
20
+
21
def switch_lora(self, v, label=None):
    """Toggle LoRA activation on the attention projections (query/key/value)."""
    for proj in (self.to_q, self.to_k, self.to_v):
        proj.set_lora_active(v, label)
24
+
25
+
26
def lora_forward(self, x, context, mask, scale=1., scale_factor= None):
    # Hacked cross-attention forward: context holds character (c1) and background
    # (c2) halves along the token dim; the two attention results are blended by a
    # soft character mask resized to the current latent resolution.
    def qkv_forward(x, context):
        q = self.to_q(x)
        k = self.to_k(context)
        v = self.to_v(context)
        return q, k, v

    assert exists(scale_factor), "Scale factor must be assigned before masked attention"

    # Resize the mask to the latent grid and flatten to per-token weights.
    mask = rearrange(
        F.interpolate(mask, scale_factor=scale_factor, mode="bicubic"),
        "b c h w -> b (h w) c"
    ).contiguous()

    c1, c2 = context.chunk(2, dim=1)

    # Background region cross-attention (foreground LoRA disabled first)
    if self.use_lora:
        self.switch_lora(False, "foreground")
    q2, k2, v2 = qkv_forward(x, c2)
    bg_out = self.attn_forward(q2, k2, v2, scale) * self.bg_scale

    # Character region cross-attention (foreground LoRA re-enabled)
    if self.use_lora:
        self.switch_lora(True, "foreground")
    q1, k1, v1 = qkv_forward(x, c1)
    fg_out = self.attn_forward(q1, k1, v1, scale) * self.fg_scale

    # Soft-merge foreground toward background, then composite by the soft mask.
    fg_out = fg_out * (1 - self.merge_scale) + bg_out * self.merge_scale
    return fg_out * mask + bg_out * (1 - mask)
    # return torch.where(mask > self.mask_threshold, fg_out, bg_out)
57
+
58
+
59
def dual_lora_forward(self, x, context, mask, scale=1., scale_factor=None):
    """
    This function hacks cross-attention layers, using two LoRA sets
    ("foreground" and "background") and a hard mask composite.
    Args:
        x: Query input
        context: Key and value input (character half + background half, token dim)
        mask: Character mask
        scale: Attention scale
        scale_factor: Current latent size factor

    """
    def qkv_forward(x, context):
        q = self.to_q(x)
        k = self.to_k(context)
        v = self.to_v(context)
        return q, k, v

    assert exists(scale_factor), "Scale factor must be assigned before masked attention"

    # Resize the mask to the latent grid and flatten to per-token weights.
    mask = rearrange(
        F.interpolate(mask, scale_factor=scale_factor, mode="bicubic"),
        "b c h w -> b (h w) c"
    ).contiguous()

    c1, c2 = context.chunk(2, dim=1)

    # Background region cross-attention (only the background LoRA active)
    if self.use_lora:
        self.switch_lora(True, "background")
        self.switch_lora(False, "foreground")
    q2, k2, v2 = qkv_forward(x, c2)
    bg_out = self.attn_forward(q2, k2, v2, scale) * self.bg_scale

    # Foreground region cross-attention (only the foreground LoRA active)
    if self.use_lora:
        self.switch_lora(False, "background")
        self.switch_lora(True, "foreground")
    q1, k1, v1 = qkv_forward(x, c1)
    fg_out = self.attn_forward(q1, k1, v1, scale) * self.fg_scale

    # Soft-merge foreground toward background, then hard composite by threshold.
    fg_out = fg_out * (1 - self.merge_scale) + bg_out * self.merge_scale
    # return fg_out * mask + bg_out * (1 - mask)
    return torch.where(mask > self.mask_threshold, fg_out, bg_out)
102
+
103
+
104
+
105
class MultiLoraInjectedLinear(nn.Linear):
    """
    A linear layer that can hold multiple LoRA adapters and merge them.

    The base weight/bias are frozen. Each adapter is a (down, up, dropout)
    triple keyed by a string label; adapters can be toggled, rescaled, merged
    into the base weights, and recovered independently.
    """
    def __init__(
        self,
        in_features,
        out_features,
        bias = False,
    ):
        super().__init__(in_features, out_features, bias)
        self.lora_adapters: Dict[str, Dict[str, nn.Module]] = {}  # {label: {up/down: layer}}
        self.lora_scales: Dict[str, float] = {}
        self.active_loras: Dict[str, bool] = {}
        # Backups for merge/recover round-trips; populated lazily on first merge.
        self.original_weight = None
        self.original_bias = None

        # Freeze original weights
        self.weight.requires_grad_(False)
        if exists(self.bias):
            self.bias.requires_grad_(False)

    def add_lora_adapter(self, label: str, r: int, scale: float = 1.0, dropout_p: float = 0.0):
        """Add a new LoRA adapter with the given label."""
        # A float rank is interpreted as a fraction of the output width.
        if isinstance(r, float):
            r = int(r * self.out_features)

        # LoRA branch mirrors the base layer's bias configuration.
        lora_down = nn.Linear(self.in_features, r, bias=self.bias is not None)
        lora_up = nn.Linear(r, self.out_features, bias=self.bias is not None)
        dropout = nn.Dropout(dropout_p)

        # Initialize weights: up is zeroed so the new adapter starts as a no-op.
        nn.init.normal_(lora_down.weight, std=1 / r)
        nn.init.zeros_(lora_up.weight)

        self.lora_adapters[label] = {
            'down': lora_down,
            'up': lora_up,
            'dropout': dropout,
        }
        self.lora_scales[label] = scale
        self.active_loras[label] = True

        # Register as submodules (so parameters/state_dict include the adapters)
        self.add_module(f'lora_down_{label}', lora_down)
        self.add_module(f'lora_up_{label}', lora_up)
        self.add_module(f'lora_dropout_{label}', dropout)

    def get_trainable_layers(self, label: str = None):
        """Get trainable layers for specific LoRA or all LoRAs."""
        layers = []
        if exists(label):
            if label in self.lora_adapters:
                adapter = self.lora_adapters[label]
                layers.extend([adapter['down'], adapter['up']])
        else:
            for adapter in self.lora_adapters.values():
                layers.extend([adapter['down'], adapter['up']])
        return layers

    def set_lora_active(self, active: bool, label: str):
        """Activate or deactivate a specific LoRA adapter."""
        if label in self.active_loras:
            self.active_loras[label] = active

    def set_lora_scale(self, scale: float, label: str):
        """Set the scale for a specific LoRA adapter."""
        if label in self.lora_scales:
            self.lora_scales[label] = scale

    def merge_lora_weights(self, labels: List[str] = None):
        """Merge specified LoRA adapters into the base weights.

        Only adapters that are currently active are merged; merged adapters are
        deactivated afterwards so forward() does not apply them twice.
        """
        if labels is None:
            labels = list(self.lora_adapters.keys())

        # Store original weights if not already stored
        if self.original_weight is None:
            self.original_weight = self.weight.clone()
            if exists(self.bias):
                self.original_bias = self.bias.clone()

        merged_weight = self.original_weight.clone()
        merged_bias = self.original_bias.clone() if exists(self.original_bias) else None

        for label in labels:
            if label in self.lora_adapters and self.active_loras.get(label, False):
                lora_up, lora_down = self.lora_adapters[label]['up'], self.lora_adapters[label]['down']
                scale = self.lora_scales[label]

                # W' = W + s * (up @ down); b' = b + s * (up.bias + up.weight @ down.bias)
                lora_weight = lora_up.weight @ lora_down.weight
                merged_weight += scale * lora_weight

                if exists(merged_bias) and exists(lora_up.bias):
                    lora_bias = lora_up.bias + lora_up.weight @ lora_down.bias
                    merged_bias += scale * lora_bias

        # Update weights (kept frozen, like the original base weights)
        self.weight = nn.Parameter(merged_weight, requires_grad=False)
        if exists(merged_bias):
            self.bias = nn.Parameter(merged_bias, requires_grad=False)

        # Deactivate all LoRAs after merging
        for label in labels:
            self.active_loras[label] = False

    def recover_original_weight(self):
        """Recover the original weights before any LoRA modifications."""
        # NOTE(review): the recovered Parameters default to requires_grad=True,
        # unlike the frozen originals — confirm this is intentional.
        if self.original_weight is not None:
            self.weight = nn.Parameter(self.original_weight.clone())
            if exists(self.original_bias):
                self.bias = nn.Parameter(self.original_bias.clone())

        # Reactivate all LoRAs
        for label in self.active_loras:
            self.active_loras[label] = True

    def forward(self, input):
        # Base projection plus the sum of every active adapter's contribution.
        output = super().forward(input)

        # Add contributions from active LoRAs
        for label, adapter in self.lora_adapters.items():
            if self.active_loras.get(label, False):
                lora_out = adapter['up'](adapter['dropout'](adapter['down'](input)))
                output += self.lora_scales[label] * lora_out

        return output
231
+
232
+
233
class LoraModules:
    """Injects MultiLoraInjectedLinear layers into a model and manages them.

    `lora_params` is a list of dicts; each names a root module inside `sd`,
    an adapter label, and injection options passed to `inject_lora`.
    """

    def __init__(self, sd, lora_params, *args, **kwargs):
        self.modules = {}
        self.multi_lora_layers: Dict[str, MultiLoraInjectedLinear] = {}  # path -> MultiLoraLayer

        for cfg in lora_params:
            root_module = get_module_safe(sd, cfg.pop("root_module"))
            label = cfg.pop("label", "lora")
            self.inject_lora(label, root_module, **cfg)

    def inject_lora(
        self,
        label,
        root_module,
        r,
        split_forward = False,
        target_keys = ("to_q", "to_k", "to_v"),
        filter_keys = None,
        target_class = None,
        scale = 1.0,
        dropout_p = 0.0,
    ):
        """Replace matching child Linears under `root_module` with LoRA-injected
        layers (or add an adapter to already-injected ones), and monkey-patch the
        parent attention modules with a masked forward."""
        def check_condition(path, child, class_list):
            # Exclusion by substring wins over inclusion by suffix/class.
            if exists(filter_keys) and any(path.find(key) > -1 for key in filter_keys):
                return False
            if exists(target_keys) and any(path.endswith(key) for key in target_keys):
                return True
            if exists(class_list) and any(
                isinstance(child, module_class) for module_class in class_list
            ):
                return True
            return False

        def retrieve_target_modules():
            from refnet.util import get_obj_from_str
            target_class_list = [get_obj_from_str(t) for t in target_class] if exists(target_class) else None

            modules = []
            for name, module in root_module.named_modules():
                for key, child in module._modules.items():
                    full_path = name + '.' + key if name else key
                    if check_condition(full_path, child, target_class_list):
                        modules.append((module, child, key, full_path))
            return modules

        modules: list[Union[nn.Module]] = []
        retrieved_modules = retrieve_target_modules()

        for parent, child, child_name, full_path in retrieved_modules:
            # Check if this layer already has a MultiLoraInjectedLinear
            if full_path in self.multi_lora_layers:
                # Add LoRA to existing MultiLoraInjectedLinear
                multi_lora_layer = self.multi_lora_layers[full_path]
                multi_lora_layer.add_lora_adapter(label, r, scale, dropout_p)
            else:
                # Check if the current layer is already a MultiLoraInjectedLinear
                if isinstance(child, MultiLoraInjectedLinear):
                    child.add_lora_adapter(label, r, scale, dropout_p)
                    self.multi_lora_layers[full_path] = child
                else:
                    # Replace with MultiLoraInjectedLinear and add first LoRA.
                    # NOTE(review): the replacement layer's base weights are newly
                    # initialized here, not copied from `child` — confirm weights
                    # are loaded afterwards (e.g. via a checkpoint).
                    multi_lora_layer = MultiLoraInjectedLinear(
                        in_features=child.weight.shape[1],
                        out_features=child.weight.shape[0],
                        bias=exists(child.bias),
                    )

                    multi_lora_layer.add_lora_adapter(label, r, scale, dropout_p)
                    parent._modules[child_name] = multi_lora_layer
                    self.multi_lora_layers[full_path] = multi_lora_layer

            # Bind the masked cross-attention forward onto the parent module.
            if split_forward:
                parent.masked_forward = dual_lora_forward.__get__(parent, parent.__class__)
            else:
                parent.masked_forward = lora_forward.__get__(parent, parent.__class__)

            parent.use_lora = True
            parent.switch_lora = switch_lora.__get__(parent, parent.__class__)
            modules.append(parent)

        self.modules[label] = modules
        print(f"Activated {label} lora with {len(self.multi_lora_layers)} layers")
        return self.multi_lora_layers, modules

    def get_trainable_layers(self, label = None):
        """Get all trainable layers, optionally filtered by label."""
        layers = []
        for lora_layer in self.multi_lora_layers.values():
            layers += lora_layer.get_trainable_layers(label)
        return layers

    def switch_lora(self, mode, label = None):
        # Toggle one labelled adapter set, or all adapter sets when label is None.
        if exists(label):
            for layer in self.multi_lora_layers.values():
                layer.set_lora_active(mode, label)
            for module in self.modules[label]:
                module.use_lora = mode
        else:
            for layer in self.multi_lora_layers.values():
                for lora_label in layer.lora_adapters.keys():
                    layer.set_lora_active(mode, lora_label)

            for modules in self.modules.values():
                for module in modules:
                    module.use_lora = mode

    def adjust_lora_scales(self, scale, label = None):
        # Set the scale of one labelled adapter set, or all sets when label is None.
        if exists(label):
            for layer in self.multi_lora_layers.values():
                layer.set_lora_scale(scale, label)
        else:
            for layer in self.multi_lora_layers.values():
                for lora_label in layer.lora_adapters.keys():
                    layer.set_lora_scale(scale, lora_label)

    def merge_lora(self, labels = None):
        # Fold the (active) labelled adapters into the base weights of every layer.
        if labels is None:
            labels = list(self.modules.keys())
        elif isinstance(labels, str):
            labels = [labels]

        for layer in self.multi_lora_layers.values():
            layer.merge_lora_weights(labels)

    def recover_lora(self):
        # Undo merge_lora on every injected layer.
        for layer in self.multi_lora_layers.values():
            layer.recover_original_weight()

    def get_lora_info(self):
        """Get information about all LoRA adapters."""
        info = {}
        for path, layer in self.multi_lora_layers.items():
            info[path] = {
                'labels': list(layer.lora_adapters.keys()),
                'active': {label: active for label, active in layer.active_loras.items()},
                'scales': layer.lora_scales.copy()
            }
        return info
refnet/modules/proj.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from refnet.modules.layers import zero_module
5
+ from refnet.modules.attention import MemoryEfficientAttention
6
+ from refnet.modules.transformer import BasicTransformerBlock
7
+ from refnet.util import checkpoint_wrapper, exists
8
+ from refnet.util import load_weights
9
+
10
+
11
class NormalizedLinear(nn.Module):
    """Linear projection followed by LayerNorm, optionally gradient-checkpointed."""

    def __init__(self, dim, output_dim, checkpoint=True):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(dim, output_dim),
            nn.LayerNorm(output_dim),
        )
        self.checkpoint = checkpoint

    @checkpoint_wrapper
    def forward(self, x):
        return self.layers(x)
23
+
24
+
25
class GlobalProjection(nn.Module):
    """Expand a global embedding into heads[0] * heads[1] tokens of width output_dim.

    proj1 produces heads[0] tokens of dim_head; proj2 (zero-initialized) expands
    each into heads[1] output tokens, so the projection initially contributes zeros.
    """

    def __init__(self, input_dim, output_dim, heads, dim_head=128, checkpoint=True):
        super().__init__()
        self.c_dim = output_dim
        self.dim_head = dim_head
        # (token count after proj1, total token count after proj2)
        self.head = (heads[0], heads[0] * heads[1])

        self.proj1 = nn.Linear(input_dim, dim_head * heads[0])
        self.proj2 = nn.Sequential(
            nn.SiLU(),
            zero_module(nn.Linear(dim_head, output_dim * heads[1])),
        )
        self.norm = nn.LayerNorm(output_dim)
        self.checkpoint = checkpoint

    @checkpoint_wrapper
    def forward(self, x):
        x = self.proj1(x).reshape(-1, self.head[0], self.dim_head).contiguous()
        x = self.proj2(x).reshape(-1, self.head[1], self.c_dim).contiguous()
        return self.norm(x)
45
+
46
+
47
class ClusterConcat(nn.Module):
    """Attend over input tokens, keep the first `token_length`, concatenate a
    conditioning embedding along channels, and project to the output width."""

    def __init__(self, input_dim, c_dim, output_dim, dim_head=64, token_length=196, checkpoint=True):
        super().__init__()
        self.attn = MemoryEfficientAttention(input_dim, dim_head=dim_head)
        self.norm = nn.LayerNorm(input_dim)
        self.proj = nn.Sequential(
            nn.Linear(input_dim + c_dim, output_dim),
            nn.SiLU(),
            nn.Linear(output_dim, output_dim),
            nn.LayerNorm(output_dim)
        )
        self.token_length = token_length
        self.checkpoint = checkpoint

    @checkpoint_wrapper
    def forward(self, x, emb, fgbg=False, *args, **kwargs):
        x = self.attn(x)[:, :self.token_length]  # truncate to a fixed token count
        x = self.norm(x)
        x = torch.cat([x, emb], 2)
        x = self.proj(x)

        if fgbg:
            # Batch holds fg/bg pairs: fold the two batch halves into the token dim.
            x = torch.cat(torch.chunk(x, 2), 1)
        return x
71
+
72
+
73
class RecoveryClusterConcat(ClusterConcat):
    """ClusterConcat variant that refines background tokens with self-attention."""

    def __init__(self, input_dim, c_dim, output_dim, dim_head=64, *args, **kwargs):
        super().__init__(input_dim, c_dim, output_dim, dim_head=dim_head, *args, **kwargs)
        # Self-attention-only transformer block (cross-attention disabled).
        self.transformer = BasicTransformerBlock(
            output_dim, output_dim//dim_head, dim_head,
            disable_cross_attn=True, checkpoint=False
        )

    @checkpoint_wrapper
    def forward(self, x, emb, bg=False):
        x = self.attn(x)[:, :self.token_length]
        x = self.norm(x)
        x = torch.cat([x, emb], 2)
        x = self.proj(x)

        if bg:
            # Extra refinement pass applied only to background conditioning.
            x = self.transformer(x)
        return x
91
+
92
+
93
class LogitClusterConcat(ClusterConcat):
    """ClusterConcat that first maps the conditioning through a frozen AdaptiveMLP."""

    def __init__(self, c_dim, mlp_in_dim, mlp_ckpt_path=None, *args, **kwargs):
        super().__init__(c_dim=c_dim, *args, **kwargs)
        self.mlp = AdaptiveMLP(c_dim, mlp_in_dim)
        if exists(mlp_ckpt_path):
            self.mlp.load_state_dict(load_weights(mlp_ckpt_path), strict=True)

    @checkpoint_wrapper
    def forward(self, x, emb, bg=False):
        # The MLP acts as a fixed feature extractor: no gradients flow through it.
        with torch.no_grad():
            emb = self.mlp(emb).detach()
        return super().forward(x, emb, bg)
105
+
106
+
107
class AdaptiveMLP(nn.Module):
    """MLP whose per-stage features are concatenated and fused into one output."""

    def __init__(self, dim, in_dim, layers=4, checkpoint=True):
        super().__init__()
        blocks = [nn.Sequential(nn.Linear(in_dim, dim))]
        blocks += [
            nn.Sequential(nn.SiLU(), nn.LayerNorm(dim), nn.Linear(dim, dim))
            for _ in range(1, layers)
        ]
        self.mlp = nn.Sequential(*blocks)
        self.fusion_layer = nn.Linear(dim * layers, dim, bias=False)
        self.norm = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    @checkpoint_wrapper
    def forward(self, x):
        # Collect the activation after every stage, then fuse them all.
        features = []
        for stage in self.mlp:
            x = stage(x)
            features.append(x)

        fused = self.fusion_layer(torch.cat(features, dim=2))
        return self.norm(fused)
135
+
136
+
137
class Concat(nn.Module):
    """Concatenate two tensors along the last dimension; extra args are ignored."""

    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x, y, *args, **kwargs):
        return torch.cat((x, y), dim=-1)
refnet/modules/reference_net.py ADDED
@@ -0,0 +1,430 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from einops import rearrange
6
+ from typing import Union
7
+ from functools import partial
8
+
9
+ from refnet.modules.unet_old import (
10
+ timestep_embedding,
11
+ conv_nd,
12
+ TimestepEmbedSequential,
13
+ exists,
14
+ ResBlock,
15
+ linear,
16
+ Downsample,
17
+ zero_module,
18
+ SelfTransformerBlock,
19
+ SpatialTransformer,
20
+ )
21
+ from refnet.modules.unet import DualCondUNetXL
22
+
23
+
24
def hack_inference_forward(model):
    # Monkey-patch: rebind the instance's `forward` to the inference-only
    # implementation, bound via the descriptor protocol so `self` works.
    model.forward = InferenceForward.__get__(model, model.__class__)
26
+
27
+
28
def hack_unet_forward(unet):
    """Swap the UNet's `_forward` for the reference-injected variant, saving the original."""
    unet.original_forward = unet._forward
    forward_fn = enhanced_forward_xl if isinstance(unet, DualCondUNetXL) else enhanced_forward
    unet._forward = forward_fn.__get__(unet, unet.__class__)
34
+
35
+
36
def restore_unet_forward(unet):
    """Undo `hack_unet_forward`: rebind the saved `_forward` and drop the backup."""
    if not hasattr(unet, "original_forward"):
        return
    unet._forward = unet.original_forward.__get__(unet, unet.__class__)
    del unet.original_forward
40
+
41
+
42
def modulation(x, scale, shift):
    """FiLM-style modulation: scale around identity, then shift."""
    return x * (scale + 1) + shift
44
+
45
+
46
def enhanced_forward(
    self,
    x: torch.Tensor,
    emb: torch.Tensor,
    hs_fg: torch.Tensor = None,
    hs_bg: torch.Tensor = None,
    mask: torch.Tensor = None,
    threshold: Union[float|torch.Tensor] = None,
    control: torch.Tensor = None,
    context: torch.Tensor = None,
    style_modulations: torch.Tensor = None,
    **additional_context
):
    """Hacked UNet `_forward` that injects reference features into skip connections.

    Where the resized character mask exceeds `threshold`, skip features are taken
    from the foreground branch (hs_fg via map_modules), otherwise from the
    background branch (hs_bg via warp_modules). Optional FiLM-style modulation
    (`style_modulations`) is applied after every output block.
    """
    h = x.to(self.dtype)
    emb = emb.to(self.dtype)
    hs = []

    control_iter = iter(control)
    for idx, module in enumerate(self.input_blocks):
        h = module(h, emb, context, mask, **additional_context)

        # Add hint-encoder control residuals at the configured encoder depths.
        if idx in self.hint_encoder_index:
            h += next(control_iter)

        hs.append(h)

    h = self.middle_block(h, emb, context, mask, **additional_context)

    for idx, module in enumerate(self.output_blocks):
        h_skip = hs.pop()

        if exists(mask) and exists(threshold):
            # inject foreground/background features into the skip connection
            B, C, H, W = h_skip.shape
            cm = F.interpolate(mask, (H, W), mode="bicubic")
            h = torch.cat([h, torch.where(
                cm > threshold,
                self.map_modules[idx](h_skip, hs_fg[idx]) if exists(hs_fg) else h_skip,
                self.warp_modules[idx](h_skip, hs_bg[idx]) if exists(hs_bg) else h_skip
            )], 1)

        else:
            h = torch.cat([h, h_skip], 1)

        h = module(h, emb, context, mask, **additional_context)

        if exists(style_modulations):
            # FiLM: scale/shift predicted from the style token + timestep embedding.
            style_norm, emb_proj, style_proj = self.style_modules[idx]
            style_m = style_modulations[idx] + emb_proj(emb)
            style_m = style_proj(style_norm(style_m))[...,None,None]
            scale, shift = style_m.chunk(2, dim=1)

            h = modulation(h, scale, shift)

    return h
101
+
102
def enhanced_forward_xl(
    self,
    x: torch.Tensor,
    emb: torch.Tensor,
    z_fg: torch.Tensor = None,
    z_bg: torch.Tensor = None,
    hs_fg: torch.Tensor = None,
    hs_bg: torch.Tensor = None,
    mask: torch.Tensor = None,
    inject_mask: torch.Tensor = None,
    threshold: Union[float, torch.Tensor] = None,
    concat: torch.Tensor = None,
    control: torch.Tensor = None,
    context: torch.Tensor = None,
    style_modulations: torch.Tensor = None,
    **additional_context
):
    """Replacement ``_forward`` for ``DualCondUNetXL`` (installed by
    :func:`hack_unet_forward`).

    Extends :func:`enhanced_forward` with:
      * optional ``concat`` channels fused in through ``self.concat_conv``;
      * one-shot latent residuals ``z_fg`` / ``z_bg`` added to the first
        encoder block that sees them (then cleared so they apply only once);
      * attention-based fg/bg injection on decoder skips, gated by
        ``inject_mask``/``threshold``, with a residual around each injected
        attention (``self.map_modules`` / ``self.warp_modules``);
      * extra ``control`` residuals inside the decoder at
        ``self.hint_decoder_index``.

    NOTE(review): like the non-XL variant, ``control=None`` (the default)
    raises ``TypeError`` at ``iter(control)`` — verify callers always pass it.
    """
    h = x.to(self.dtype)
    emb = emb.to(self.dtype)
    hs = []
    control_iter = iter(control)

    if exists(concat):
        # Fuse the extra conditioning channels with a learned residual.
        h = torch.cat([h, concat], 1)
        h = h + self.concat_conv(h)

    for idx, module in enumerate(self.input_blocks):
        h = module(h, emb, context, mask, **additional_context)

        if idx in self.hint_encoder_index:
            h += next(control_iter)

        # fg/bg latents are consumed exactly once, then cleared.
        if exists(z_fg):
            h += self.conv_fg(z_fg)
            z_fg = None
        if exists(z_bg):
            h += self.conv_bg(z_bg)
            z_bg = None

        hs.append(h)

    h = self.middle_block(h, emb, context, mask, **additional_context)

    for idx, module in enumerate(self.output_blocks):
        h_skip = hs.pop()

        if exists(inject_mask) and exists(threshold):
            # inject foreground/background features
            B, C, H, W = h_skip.shape
            cm = F.interpolate(inject_mask, (H, W), mode="bicubic")
            h = torch.cat([h, torch.where(
                cm > threshold,

                # foreground injection: attend from skip tokens to the
                # time-conditioned reference features, residual around it
                rearrange(
                    self.map_modules[idx][0](
                        rearrange(h_skip, "b c h w -> b (h w) c"),
                        hs_fg[idx] + self.map_modules[idx][1](emb).unsqueeze(1)
                    ), "b (h w) c -> b c h w", h=H, w=W
                ) + h_skip if exists(hs_fg) else h_skip,

                # background injection (same scheme, warp modules)
                rearrange(
                    self.warp_modules[idx][0](
                        rearrange(h_skip, "b c h w -> b (h w) c"),
                        hs_bg[idx] + self.warp_modules[idx][1](emb).unsqueeze(1)
                    ), "b (h w) c -> b c h w", h=H, w=W
                ) + h_skip if exists(hs_bg) else h_skip
            )], 1)

        else:
            h = torch.cat([h, h_skip], 1)

        h = module(h, emb, context, mask, **additional_context)

        if exists(style_modulations):
            # Per-block style modulation (scale/shift from style + time emb).
            style_norm, emb_proj, style_proj = self.style_modules[idx]
            style_m = style_modulations[idx] + emb_proj(emb)
            style_m = style_proj(style_norm(style_m))[...,None,None]
            scale, shift = style_m.chunk(2, dim=1)

            h = modulation(h, scale, shift)

        # Decoder-side control hints.
        if idx in self.hint_decoder_index:
            h += next(control_iter)

    return h
189
+
190
def InferenceForward(self, x, timesteps=None, y=None, *args, **kwargs):
    """Inference-time ``forward`` bound onto a model by
    :func:`hack_inference_forward`.

    Builds the timestep (+optional class) embedding in ``self.dtype`` and
    delegates the denoising pass to ``self._forward``.
    """
    sinusoid = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
    emb = self.time_embed(sinusoid).to(self.dtype)

    assert (y is not None) == (
        self.num_classes is not None
    ), "must specify y if and only if the model is class-conditional"

    if self.num_classes is not None:
        assert y.shape[0] == x.shape[0]
        emb = (emb + self.label_emb(y.to(self.dtype))).to(self.dtype)

    return self._forward(x, emb, *args, **kwargs)
202
+
203
+
204
class UNetEncoderXL(nn.Module):
    """Encoder-only UNet (input blocks only) producing per-resolution hints.

    ``_forward`` returns one "hint" per input block, in reverse (deepest
    first): either a zero-initialized style vector per block (when
    ``style_modulation`` is set, via global average pooling + Linear) or a
    zero-initialized 1x1-conv feature map flattened to tokens.  Zero
    initialization means the encoder contributes nothing at the start of
    training (ControlNet-style).
    """
    transformers = {
        "vanilla": SpatialTransformer,
    }

    def __init__(
        self,
        in_channels,
        model_channels,
        num_res_blocks,
        attention_resolutions,
        dropout = 0,
        channel_mult = (1, 2, 4, 8),
        conv_resample = True,
        dims = 2,
        num_classes = None,
        use_checkpoint = False,
        num_heads = -1,
        num_head_channels = -1,
        use_scale_shift_norm = False,
        resblock_updown = False,
        use_spatial_transformer = False,  # custom transformer support
        transformer_depth = 1,  # custom transformer support
        context_dim = None,  # custom transformer support
        disable_self_attentions = None,
        disable_cross_attentions = None,
        num_attention_blocks = None,
        use_linear_in_transformer = False,
        adm_in_channels = None,
        transformer_type = "vanilla",
        style_modulation = False,
    ):
        super().__init__()
        if use_spatial_transformer:
            assert exists(
                context_dim) or disable_cross_attentions, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
            assert transformer_type in self.transformers.keys(), f'Assigned transformer is not implemented.. Choices: {self.transformers.keys()}'
            from omegaconf.listconfig import ListConfig
            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        time_embed_dim = model_channels * 4
        # NOTE(review): this ``resblock`` partial is never used below —
        # ``ResBlock`` is instantiated directly.  Dead code candidate.
        resblock = partial(
            ResBlock,
            emb_channels=time_embed_dim,
            dropout=dropout,
            dims=dims,
            use_checkpoint=use_checkpoint,
            use_scale_shift_norm=use_scale_shift_norm,
        )

        transformer = partial(
            self.transformers[transformer_type],
            context_dim=context_dim,
            use_linear=use_linear_in_transformer,
            use_checkpoint=use_checkpoint,
            disable_self_attn=disable_self_attentions,
            disable_cross_attn=disable_cross_attentions,
        )

        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
        self.in_channels = in_channels
        self.model_channels = model_channels
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                 "as a list/tuple (per-level) with the same length as channel_mult")
            self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(
                map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                  f"attention will still not be set.")

        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = torch.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.style_modulation = style_modulation

        if isinstance(transformer_depth, int):
            transformer_depth = len(channel_mult) * [transformer_depth]

        time_embed_dim = model_channels * 4
        zero_conv = partial(nn.Conv2d, kernel_size=1, stride=1, padding=0)

        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
            elif self.num_classes == "continuous":
                print("setting up linear c_adm embedding layer")
                self.label_emb = nn.Linear(1, time_embed_dim)
            elif self.num_classes == "sequential":
                assert adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
                        linear(adm_in_channels, time_embed_dim),
                        nn.SiLU(),
                        linear(time_embed_dim, time_embed_dim),
                    )
                )
            else:
                raise ValueError()

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        # One zero-initialized hint head per input block (kept in lockstep
        # with input_blocks throughout construction).
        self.zero_layers = nn.ModuleList([zero_module(
            nn.Linear(model_channels, model_channels * 2) if style_modulation else
            zero_conv(model_channels, model_channels)
        )])

        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    # NOTE(review): clobbers the constructor's ``num_heads``;
                    # assumes num_head_channels is set (>0) whenever attention
                    # is enabled at this resolution.
                    num_heads = ch // num_head_channels

                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                        layers.append(
                            SelfTransformerBlock(ch, num_head_channels)
                            if not use_spatial_transformer
                            else transformer(
                                ch, num_heads, num_head_channels, depth=transformer_depth[level]
                            )
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self.zero_layers.append(zero_module(
                    nn.Linear(ch, ch * 2) if style_modulation else zero_conv(ch, ch)
                ))

            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(TimestepEmbedSequential(
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=out_ch,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                        down=True,
                    ) if resblock_updown else Downsample(
                        ch, conv_resample, dims=dims, out_channels=out_ch
                    )
                ))
                # NOTE(review): the style-modulation head after a downsample
                # projects to min(model_channels*8, out_ch*4), unlike the
                # ch*2 used elsewhere — confirm the consumer expects this.
                self.zero_layers.append(zero_module(
                    nn.Linear(out_ch, min(model_channels * 8, out_ch * 4)) if style_modulation else
                    zero_conv(out_ch, out_ch)
                ))
                ch = out_ch
                ds *= 2


    def forward(self, x, timesteps = None, y = None, *args, **kwargs):
        """Embed the timestep (+optional label ``y``) and run the encoder.

        Returns the list of hints produced by :meth:`_forward`.
        """
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
        emb = self.time_embed(t_emb)

        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y.to(self.dtype))

        hs = self._forward(x, emb, *args, **kwargs)
        return hs

    def _forward(self, x, emb, context = None, **additional_context):
        """Run the input blocks, collecting one hint per block.

        Hints are reversed so the deepest (lowest-resolution) hint comes
        first, matching decoder iteration order in the consumer.
        """
        hints = []
        h = x.to(self.dtype)

        for idx, module in enumerate(self.input_blocks):
            h = module(h, emb, context, **additional_context)

            if self.style_modulation:
                # Global-average-pool to a per-block style vector.
                hint = self.zero_layers[idx](h.mean(dim=[2, 3]))
                hints.append(hint)

            else:
                # 1x1 conv then flatten to (B, H*W, C) token layout.
                hint = self.zero_layers[idx](h)
                hint = rearrange(hint, "b c h w -> b (h w) c").contiguous()
                hints.append(hint)

        hints.reverse()
        return hints
refnet/modules/transformer.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from functools import partial
5
+ from einops import rearrange
6
+
7
+ from refnet.util import checkpoint_wrapper, exists
8
+ from refnet.modules.layers import FeedForward, Normalize, zero_module, RMSNorm
9
+ from refnet.modules.attention import MemoryEfficientAttention, MultiModalAttention, MultiScaleCausalAttention
10
+
11
+
12
class BasicTransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention, optional cross-attention,
    then a feed-forward, each wrapped in a residual connection.

    Attention implementations are looked up by name in ``ATTENTION_MODES``.
    ``reference_scale`` and ``scale_factor`` are mutable knobs read by the
    cross-attention call; they are set externally (e.g. by the sampler).
    """
    ATTENTION_MODES = {
        "vanilla": MemoryEfficientAttention,
        "multi-scale": MultiScaleCausalAttention,
        "multi-modal": MultiModalAttention,
    }
    def __init__(
        self,
        dim,
        n_heads = None,
        d_head = 64,
        dropout = 0.,
        context_dim = None,
        gated_ff = True,
        ff_mult = 4,
        checkpoint = True,
        disable_self_attn = False,
        disable_cross_attn = False,
        self_attn_type = "vanilla",
        cross_attn_type = "vanilla",
        rotary_positional_embedding = False,
        context_dim_2 = None,
        casual_self_attn = False,  # NOTE(review): "casual" presumably means "causal"; spelling kept — attention modules use it
        casual_cross_attn = False,
        qk_norm = False,
        norm_type = "layer",
    ):
        super().__init__()
        assert self_attn_type in self.ATTENTION_MODES
        assert cross_attn_type in self.ATTENTION_MODES
        self_attn_cls = self.ATTENTION_MODES[self_attn_type]
        crossattn_cls = self.ATTENTION_MODES[cross_attn_type]

        if norm_type == "layer":
            norm_cls = nn.LayerNorm
        elif norm_type == "rms":
            norm_cls = RMSNorm
        else:
            raise NotImplementedError(f"Normalization {norm_type} is not implemented.")

        self.dim = dim
        self.disable_self_attn = disable_self_attn
        self.disable_cross_attn = disable_cross_attn

        # NOTE(review): with disable_self_attn, attn1 is built with a
        # context_dim, but forward() never passes ``context`` to attn1 —
        # confirm this path is exercised / intended.
        self.attn1 = self_attn_cls(
            query_dim = dim,
            heads = n_heads,
            dim_head = d_head,
            dropout = dropout,
            context_dim = context_dim if self.disable_self_attn else None,
            casual = casual_self_attn,
            rope = rotary_positional_embedding,
            qk_norm = qk_norm
        )
        self.attn2 = crossattn_cls(
            query_dim = dim,
            context_dim = context_dim,
            context_dim_2 = context_dim_2,
            heads = n_heads,
            dim_head = d_head,
            dropout = dropout,
            casual = casual_cross_attn
        ) if not disable_cross_attn else None

        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, mult=ff_mult)
        self.norm1 = norm_cls(dim)
        self.norm2 = norm_cls(dim) if not disable_cross_attn else None
        self.norm3 = norm_cls(dim)
        self.reference_scale = 1
        self.scale_factor = None
        self.checkpoint = checkpoint

    @checkpoint_wrapper
    def forward(self, x, context=None, mask=None, emb=None, **kwargs):
        """Residual pipeline; extra ``kwargs`` are forwarded to attn1 only."""
        x = self.attn1(self.norm1(x), **kwargs) + x
        if not self.disable_cross_attn:
            x = self.attn2(self.norm2(x), context, mask, self.reference_scale, self.scale_factor) + x
        x = self.ff(self.norm3(x)) + x
        return x
91
+
92
+
93
class SelfInjectedTransformerBlock(BasicTransformerBlock):
    """Transformer block whose self-attention can attend over an injected
    reference feature bank (``self.bank``, set externally by the pipeline).

    With a bank present, attn1 attends from ``x`` to ``[x, bank]`` (when
    ``injection_type == "concat"``) or to ``x + bank``; without one it falls
    back to the parent block's forward.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.bank = None        # reference features injected at sampling time
        self.time_proj = None   # optional projection of emb added to the bank
        self.injection_type = "concat"
        self.forward_without_bank = super().forward  # pre-bound parent pipeline

    @checkpoint_wrapper
    def forward(self, x, context=None, mask=None, emb=None, **kwargs):
        if exists(self.bank):
            bank = self.bank
            # Broadcast a single-sample bank across the batch (e.g. CFG doubling).
            if bank.shape[0] != x.shape[0]:
                bank = bank.repeat(x.shape[0], 1, 1)
            if exists(self.time_proj) and exists(emb):
                bank = bank + self.time_proj(emb).unsqueeze(1)
            x_in = self.norm1(x)

            # Share the cross-attention mask threshold with self-attention.
            self.attn1.mask_threshold = self.attn2.mask_threshold
            x = self.attn1(
                x_in,
                torch.cat([x_in, bank], 1) if self.injection_type == "concat" else x_in + bank,
                mask = mask,
                scale_factor = self.scale_factor,
                **kwargs
            ) + x

            x = self.attn2(
                self.norm2(x),
                context,
                mask = mask,
                scale = self.reference_scale,
                scale_factor = self.scale_factor
            ) + x

            x = self.ff(self.norm3(x)) + x
        else:
            # NOTE(review): extra **kwargs are not forwarded on this fallback
            # path — confirm callers don't rely on them when no bank is set.
            x = self.forward_without_bank(x, context, mask, emb)
        return x
132
+
133
+
134
class SelfTransformerBlock(nn.Module):
    """Lightweight self-attention + MLP block over (B, C, H, W) feature maps.

    With ``reshape=True`` (default) the map is flattened to tokens before
    attention and restored afterwards.  An optional ``context`` is used as
    the attention's key/value source.
    """
    def __init__(
        self,
        dim,
        dim_head = 64,
        dropout = 0.,
        mlp_ratio = 4,
        checkpoint = True,
        casual_attn = False,
        reshape = True
    ):
        super().__init__()
        # Head count derived from dim so each head is exactly dim_head wide.
        self.attn = MemoryEfficientAttention(query_dim=dim, heads=dim//dim_head, dropout=dropout, casual=casual_attn)
        self.ff = nn.Sequential(
            nn.Linear(dim, dim * mlp_ratio),
            nn.SiLU(),
            zero_module(nn.Linear(dim * mlp_ratio, dim))  # zero-init output proj
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.reshape = reshape
        self.checkpoint = checkpoint

    @checkpoint_wrapper
    def forward(self, x, context=None):
        # NOTE(review): this unpack assumes a 4-D input even when
        # reshape=False — verify callers pass (B, C, H, W) in that mode.
        b, c, h, w = x.shape
        if self.reshape:
            x = rearrange(x, 'b c h w -> b (h w) c').contiguous()

        x = self.attn(self.norm1(x), context if exists(context) else None) + x
        x = self.ff(self.norm2(x)) + x

        if self.reshape:
            x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        return x
169
+
170
+
171
class Transformer(nn.Module):
    """Spatial transformer wrapper: projects a (B, C, H, W) map to an inner
    token dimension, applies ``depth`` transformer blocks, projects back and
    adds a residual.

    ``use_linear`` selects linear projections applied on flattened tokens
    instead of 1x1 convolutions applied before flattening / after restoring.
    The output projection is zero-initialized, so the whole module starts as
    an identity mapping.
    """
    transformer_type = {
        "vanilla": BasicTransformerBlock,
        "self-injection": SelfInjectedTransformerBlock,
    }
    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None, use_linear=False,
                 use_checkpoint=True, type="vanilla", transformer_config=None, **kwargs):
        super().__init__()
        transformer_block = self.transformer_type[type]
        # Normalize context_dim to one entry per block depth.
        if not isinstance(context_dim, list):
            context_dim = [context_dim]
        if isinstance(context_dim, list):  # always true after the line above
            if depth != len(context_dim):
                context_dim = depth * [context_dim[0]]

        proj_layer = nn.Linear if use_linear else partial(nn.Conv2d, kernel_size=1, stride=1, padding=0)
        inner_dim = n_heads * d_head

        self.in_channels = in_channels
        self.proj_in = proj_layer(in_channels, inner_dim)
        self.transformer_blocks = nn.ModuleList([
            transformer_block(
                inner_dim,
                n_heads,
                d_head,
                dropout = dropout,
                context_dim = context_dim[d],
                checkpoint = use_checkpoint,
                **(transformer_config or {}),
                **kwargs
            ) for d in range(depth)
        ])
        self.proj_out = zero_module(proj_layer(inner_dim, in_channels))
        self.norm = Normalize(in_channels)
        self.use_linear = use_linear

    def forward(self, x, context=None, mask=None, emb=None, *args, **additional_context):
        """Tokenize, run the blocks (each sees the spatial ``grid_size``),
        restore the map layout, and add the input residual."""
        # note: if no context is given, cross-attention defaults to self-attention
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            x = block(x, context=context, mask=mask, emb=emb, grid_size=(h, w), *args, **additional_context)
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in
226
+
227
+
228
def SpatialTransformer(*args, **kwargs):
    """Factory for a :class:`Transformer` built with plain ("vanilla") blocks."""
    return Transformer(*args, type="vanilla", **kwargs)
230
+
231
def SelfInjectTransformer(*args, **kwargs):
    """Factory for a :class:`Transformer` built with self-injection blocks."""
    return Transformer(*args, type="self-injection", **kwargs)
refnet/modules/unet.py ADDED
@@ -0,0 +1,421 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from functools import partial
5
+ from refnet.modules.attention import MemoryEfficientAttention
6
+ from refnet.util import exists
7
+ from refnet.modules.transformer import (
8
+ SelfTransformerBlock,
9
+ Transformer,
10
+ SpatialTransformer,
11
+ SelfInjectTransformer,
12
+ )
13
+ from refnet.ldm.openaimodel import (
14
+ timestep_embedding,
15
+ conv_nd,
16
+ TimestepBlock,
17
+ zero_module,
18
+ ResBlock,
19
+ linear,
20
+ Downsample,
21
+ Upsample,
22
+ normalization,
23
+ )
24
+
25
+
26
def hack_inference_forward(model):
    """Rebind ``model.forward`` to the inference-only :func:`InferenceForward`."""
    patched = InferenceForward.__get__(model, type(model))
    model.forward = patched
28
+
29
+
30
def InferenceForward(self, x, timesteps=None, y=None, *args, **kwargs):
    """Inference-time ``forward`` installed by :func:`hack_inference_forward`.

    Embeds the timestep (and optional class label ``y``), runs the UNet body
    via ``self._forward``, and applies the output head, casting the result
    back to the input dtype.
    """
    t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
    emb = self.time_embed(t_emb).to(self.dtype)
    assert (y is not None) == (
        self.num_classes is not None
    ), "must specify y if and only if the model is class-conditional"
    if self.num_classes is not None:
        assert y.shape[0] == x.shape[0]
        # NOTE(review): the label is moved with ``.to(emb.device)`` here while
        # the sibling implementation casts with ``.to(self.dtype)`` — confirm
        # which is intended for mixed-precision checkpoints.
        emb = emb + self.label_emb(y.to(emb.device))
        emb = emb.to(self.dtype)
    h = self._forward(x, emb, *args, **kwargs)
    return self.out(h.to(x.dtype))
42
+
43
+
44
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """
    # Dispatch constants
    _D_TIMESTEP = 0
    _D_TRANSFORMER = 1
    _D_OTHER = 2

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Cache dispatch types at init (before FSDP wrapping), so forward()
        # needs no isinstance checks and is immune to FSDP wrapper breakage.
        self._dispatch = tuple(
            self._D_TIMESTEP if isinstance(layer, TimestepBlock) else
            self._D_TRANSFORMER if isinstance(layer, Transformer) else
            self._D_OTHER
            for layer in self
        )

    def forward(self, x, emb=None, context=None, mask=None, **additional_context):
        """Run children in order, routing extras per cached dispatch type:
        TimestepBlocks receive ``emb``; Transformers receive ``context``,
        ``mask``, ``emb`` plus any additional context; everything else just
        receives ``x``."""
        for layer, d in zip(self, self._dispatch):
            if d == self._D_TIMESTEP:
                x = layer(x, emb)
            elif d == self._D_TRANSFORMER:
                x = layer(x, context, mask, emb, **additional_context)
            else:
                x = layer(x)
        return x
74
+
75
+
76
+
77
+ class UNetModel(nn.Module):
78
+ transformers = {
79
+ "vanilla": SpatialTransformer,
80
+ "selfinj": SelfInjectTransformer,
81
+ }
82
    def __init__(
        self,
        in_channels,
        model_channels,
        num_res_blocks,
        attention_resolutions,
        out_channels = 4,
        dropout = 0,
        channel_mult = (1, 2, 4, 8),
        conv_resample = True,
        dims = 2,
        num_classes = None,
        use_checkpoint = False,
        num_heads = -1,
        num_head_channels = -1,
        use_scale_shift_norm = False,
        resblock_updown = False,
        use_spatial_transformer = False,  # custom transformer support
        transformer_depth = 1,  # custom transformer support
        context_dim = None,  # custom transformer support
        disable_self_attentions = None,
        disable_cross_attentions = False,
        num_attention_blocks = None,
        use_linear_in_transformer = False,
        adm_in_channels = None,
        transformer_type = "vanilla",
        map_module = False,
        warp_module = False,
        style_modulation = False,
        discard_final_layers = False,  # for reference net
        additional_transformer_config = None,
        in_channels_fg = None,
        in_channels_bg = None,
    ):
        """Build the UNet: timestep/label embedders, encoder ``input_blocks``,
        ``middle_block``, decoder ``output_blocks``, and the optional
        map/warp/style side modules consumed by the enhanced forward passes.

        ``discard_final_layers`` truncates the last decoder level and skips
        the output head (reference-net use).  ``in_channels_fg``/``_bg``
        enable zero-initialized convs for injecting fg/bg latents.
        """
        super().__init__()
        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            from omegaconf.listconfig import ListConfig
            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        assert num_heads > -1 or num_head_channels > -1, 'Either num_heads or num_head_channels has to be set'
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                 "as a list/tuple (per-level) with the same length as channel_mult")
            self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                  f"attention will still not be set.")

        self.num_classes = num_classes
        self.model_channels = model_channels
        self.dtype = torch.float32

        if isinstance(transformer_depth, int):
            transformer_depth = len(channel_mult) * [transformer_depth]
        transformer_depth_middle = transformer_depth[-1]
        time_embed_dim = model_channels * 4
        # Partials freeze the shared ResBlock/Transformer configuration.
        resblock = partial(
            ResBlock,
            emb_channels = time_embed_dim,
            dropout = dropout,
            dims = dims,
            use_checkpoint = use_checkpoint,
            use_scale_shift_norm = use_scale_shift_norm,
        )
        transformer = partial(
            self.transformers[transformer_type],
            context_dim = context_dim,
            use_linear = use_linear_in_transformer,
            use_checkpoint = use_checkpoint,
            disable_self_attn = disable_self_attentions,
            disable_cross_attn = disable_cross_attentions,
            transformer_config = additional_transformer_config
        )

        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
            elif self.num_classes == "continuous":
                print("setting up linear c_adm embedding layer")
                self.label_emb = nn.Linear(1, time_embed_dim)
            elif self.num_classes == "sequential":
                assert adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
                        linear(adm_in_channels, time_embed_dim),
                        nn.SiLU(),
                        linear(time_embed_dim, time_embed_dim),
                    )
                )
            else:
                raise ValueError()

        self.input_blocks = nn.ModuleList([
            TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))
        ])
        # Channel counts of every encoder output, popped later to size the
        # decoder skip connections.
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [resblock(ch, out_channels=mult * model_channels)]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels > -1:
                        current_num_heads = ch // num_head_channels
                        current_head_dim = num_head_channels
                    else:
                        current_num_heads = num_heads
                        current_head_dim = ch // num_heads

                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                        layers.append(
                            SelfTransformerBlock(ch, current_head_dim)
                            if not use_spatial_transformer
                            else transformer(
                                ch, current_num_heads, current_head_dim,
                                depth=transformer_depth[level],
                            )
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(TimestepEmbedSequential(
                    resblock(ch, out_channels=out_ch, down=True) if resblock_updown
                    else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                ))
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2

        if num_head_channels > -1:
            current_num_heads = ch // num_head_channels
            current_head_dim = num_head_channels
        else:
            current_num_heads = num_heads
            current_head_dim = ch // num_heads
        self.middle_block = TimestepEmbedSequential(
            resblock(ch),
            SelfTransformerBlock(ch, current_head_dim) if not use_spatial_transformer
            else transformer(ch, current_num_heads, current_head_dim, depth=transformer_depth_middle),
            resblock(ch),
        )

        self.output_blocks = nn.ModuleList([])
        self.map_modules = nn.ModuleList([])
        self.warp_modules = nn.ModuleList([])
        self.style_modules = nn.ModuleList([])

        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(self.num_res_blocks[level] + 1):
                ich = input_block_chans.pop()
                layers = [resblock(ch + ich, out_channels=model_channels * mult)]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    if num_head_channels > -1:
                        current_num_heads = ch // num_head_channels
                        current_head_dim = num_head_channels
                    else:
                        current_num_heads = num_heads
                        current_head_dim = ch // num_heads

                    if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
                        layers.append(
                            SelfTransformerBlock(ch, current_head_dim) if not use_spatial_transformer
                            else transformer(
                                ch, current_num_heads, current_head_dim, depth=transformer_depth[level]
                            )
                        )
                if level and i == self.num_res_blocks[level]:
                    out_ch = ch
                    layers.append(
                        resblock(ch, up=True) if resblock_updown else Upsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                # For a reference net, stop before the final decoder level.
                if level == 0 and discard_final_layers:
                    break

                # Side modules operate on the skip-connection channels (ich)
                # and are indexed by decoder block; consumed by the enhanced
                # forward passes.
                if map_module:
                    self.map_modules.append(nn.ModuleList([
                        MemoryEfficientAttention(
                            ich,
                            heads = ich // num_head_channels,
                            dim_head = num_head_channels
                        ),
                        nn.Linear(time_embed_dim, ich)
                    ]))

                if warp_module:
                    self.warp_modules.append(nn.ModuleList([
                        MemoryEfficientAttention(
                            ich,
                            heads = ich // num_head_channels,
                            dim_head = num_head_channels
                        ),
                        nn.Linear(time_embed_dim, ich)
                    ]))

                if style_modulation:
                    # norm + time projection + zero-init projection producing
                    # a (scale, shift) pair of width ch each.
                    self.style_modules.append(nn.ModuleList([
                        nn.LayerNorm(ch*2),
                        nn.Linear(time_embed_dim, ch*2),
                        zero_module(nn.Linear(ch*2, ch*2))
                    ]))

        if not discard_final_layers:
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
            )

        # Zero-initialized entry convs for the optional fg/bg latent inputs.
        self.conv_fg = zero_module(
            conv_nd(dims, in_channels_fg, model_channels, 3, padding=1)
        ) if exists(in_channels_fg) else None
        self.conv_bg = zero_module(
            conv_nd(dims, in_channels_bg, model_channels, 3, padding=1)
        ) if exists(in_channels_bg) else None
327
    def forward(self, x, timesteps=None, y=None, *args, **kwargs):
        """Denoising forward pass.

        Embeds `timesteps`, optionally adds the class/ADM embedding for `y`,
        runs the UNet trunk via `_forward`, and projects back to the output
        channels with `self.out`, cast back to the input dtype.
        """
        # Sinusoidal timestep embedding, computed in the model's working dtype.
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
        emb = self.time_embed(t_emb)
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            # Class-conditioning is added to the timestep embedding.
            emb = emb + self.label_emb(y.to(self.dtype))

        h = self._forward(x, emb, *args, **kwargs)
        # Cast back to the caller's dtype (trunk may run in a lower precision).
        return self.out(h).to(x.dtype)
339
+
340
    def _forward(
        self,
        x,
        emb,
        control = None,
        context = None,
        mask = None,
        **additional_context
    ):
        """UNet trunk: encoder -> middle -> decoder with skip connections.

        NOTE(review): `control` is accepted but not used in this base
        implementation; subclasses (e.g. DualCondUNetXL) consume it.
        """
        hs = []
        h = x.to(self.dtype)

        # Encoder: collect one skip activation per input block.
        for module in self.input_blocks:
            h = module(h, emb, context, mask, **additional_context)
            hs.append(h)

        h = self.middle_block(h, emb, context, mask, **additional_context)

        # Decoder: concatenate skips in reverse (stack) order along channels.
        for module in self.output_blocks:
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context, mask, **additional_context)
        return h
362
+
363
+
364
class DualCondUNetXL(UNetModel):
    """UNet variant that injects external `control` residuals (hints) at fixed
    encoder/decoder block indices, and supports channel-concatenated input.

    Parameters
    ----------
    hint_encoder_index : indices of `input_blocks` after which a control
        residual is added.
    hint_decoder_index : indices of `output_blocks` after which a control
        residual is added (empty by default).
    """
    def __init__(
        self,
        hint_encoder_index = (0, 3, 6, 8),
        hint_decoder_index = (),
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.hint_encoder_index = hint_encoder_index
        self.hint_decoder_index = hint_decoder_index

    def _forward(self, x, emb, concat=None, control=None, context=None, mask=None, **additional_context):
        h = x.to(self.dtype)
        hs = []

        # Optional extra conditioning concatenated on the channel axis.
        if exists(concat):
            h = torch.cat([h, concat], 1)

        # Hints are consumed strictly in order: encoder indices first,
        # then decoder indices. `control` must yield exactly enough tensors.
        control_iter = iter(control)
        for idx, module in enumerate(self.input_blocks):
            h = module(h, emb, context, mask, **additional_context)

            if idx in self.hint_encoder_index:
                h += next(control_iter)
            hs.append(h)

        h = self.middle_block(h, emb, context, mask, **additional_context)

        for idx, module in enumerate(self.output_blocks):
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context, mask, **additional_context)

            if idx in self.hint_decoder_index:
                h += next(control_iter)

        return h
401
+
402
+
403
class ReferenceNet(UNetModel):
    """UNet run purely for its intermediate activations.

    Built with ``discard_final_layers=True`` so no output projection exists,
    and ``forward`` deliberately returns nothing — presumably the features are
    harvested by externally attached hooks (see the attention-control hook
    module); TODO confirm against the caller.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(discard_final_layers=True, *args, **kwargs)

    def forward(self, x, timesteps=None, y=None, *args, **kwargs):
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
        emb = self.time_embed(t_emb)

        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y.to(self.dtype))
        # Result intentionally discarded: side effects only.
        self._forward(x, emb, *args, **kwargs)

    def _forward(self, *args, **kwargs):
        super()._forward(*args, **kwargs)
        return None
refnet/modules/unet_old.py ADDED
@@ -0,0 +1,596 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from functools import partial
5
+ from refnet.util import exists
6
+ from refnet.modules.transformer import (
7
+ SelfTransformerBlock,
8
+ Transformer,
9
+ SpatialTransformer,
10
+ rearrange
11
+ )
12
+ from refnet.ldm.openaimodel import (
13
+ timestep_embedding,
14
+ conv_nd,
15
+ TimestepBlock,
16
+ zero_module,
17
+ ResBlock,
18
+ linear,
19
+ Downsample,
20
+ Upsample,
21
+ normalization,
22
+ )
23
+
24
# Optional-dependency probe: memory-efficient attention is only enabled when
# xformers imports cleanly.
try:
    import xformers
    import xformers.ops
    XFORMERS_IS_AVAILBLE = True
except Exception:
    # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt are not
    # swallowed; any import failure simply disables xformers.
    XFORMERS_IS_AVAILBLE = False
30
+
31
+
32
def hack_inference_forward(model):
    # Monkey-patch: rebind `model.forward` to the inference-only variant below.
    # `__get__` turns the plain function into a bound method so `self` works.
    model.forward = InferenceForward.__get__(model, model.__class__)
34
+
35
def InferenceForward(self, x, timesteps=None, y=None, *args, **kwargs):
    """Inference-time replacement for ``UNetModel.forward``.

    Bound onto an instance by :func:`hack_inference_forward`. Differs from the
    training forward in its dtype/device handling: the time embedding is cast
    after ``time_embed``, and the final output is computed in ``x``'s dtype.
    """
    t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
    emb = self.time_embed(t_emb).to(self.dtype)
    assert (y is not None) == (
        self.num_classes is not None
    ), "must specify y if and only if the model is class-conditional"

    if self.num_classes is not None:
        assert y.shape[0] == x.shape[0]
        # `y` may live on a different device than the model at inference.
        emb = emb + self.label_emb(y.to(emb.device))
        emb = emb.to(self.dtype)
    h = self._forward(x, emb, *args, **kwargs)
    return self.out(h.to(x.dtype))
48
+
49
+
50
+
51
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb, context=None, mask=None, **additional_context):
        # Dispatch per layer type: timestep-aware blocks get `emb`,
        # transformer blocks get cross-attention context + mask,
        # everything else is called with the activation only.
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            elif isinstance(layer, Transformer):
                x = layer(x, context, mask, **additional_context)
            else:
                x = layer(x)
        return x
66
+
67
+
68
+
69
class UNetModel(nn.Module):
    """Diffusion UNet (legacy variant kept in ``unet_old.py``).

    Encoder / middle / decoder of ResBlocks with optional self- or spatial
    (cross-attention) transformers at the resolutions listed in
    ``attention_resolutions``. Optional per-decoder-level ``map``/``warp``
    self-attention modules and ``style`` modulation heads.
    """
    # Registry of supported cross-attention transformer implementations.
    transformers = {
        "vanilla": SpatialTransformer,
    }
    def __init__(
        self,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout = 0,
        channel_mult = (1, 2, 4, 8),
        conv_resample = True,
        dims = 2,
        num_classes = None,
        use_checkpoint = False,
        num_heads = -1,
        num_head_channels = -1,
        use_scale_shift_norm = False,
        resblock_updown = False,
        use_spatial_transformer = False,    # custom transformer support
        transformer_depth = 1,              # custom transformer support
        context_dim = None,                 # custom transformer support
        disable_self_attentions = None,
        disable_cross_attentions = None,
        num_attention_blocks = None,
        use_linear_in_transformer = False,
        adm_in_channels = None,
        transformer_type = "vanilla",
        map_module = False,
        warp_module = False,
        style_modulation = False,
    ):
        super().__init__()
        if use_spatial_transformer:
            assert exists(context_dim) or disable_cross_attentions, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
            assert transformer_type in self.transformers.keys(), f'Assigned transformer is not implemented.. Choices: {self.transformers.keys()}'
        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            from omegaconf.listconfig import ListConfig
            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        assert num_heads > -1 or num_head_channels > -1, 'Either num_heads or num_head_channels has to be set'
        # Normalize num_res_blocks into a per-level list.
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                 "as a list/tuple (per-level) with the same length as channel_mult")
            self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                  f"attention will still not be set.")

        self.num_classes = num_classes
        self.model_channels = model_channels
        self.dtype = torch.float32

        if isinstance(transformer_depth, int):
            transformer_depth = len(channel_mult) * [transformer_depth]
        transformer_depth_middle = transformer_depth[-1]
        time_embed_dim = model_channels * 4
        # Factories with the shared constructor arguments pre-bound.
        resblock = partial(
            ResBlock,
            emb_channels=time_embed_dim,
            dropout=dropout,
            dims=dims,
            use_checkpoint=use_checkpoint,
            use_scale_shift_norm=use_scale_shift_norm,
        )
        transformer = partial(
            self.transformers[transformer_type],
            context_dim=context_dim,
            use_linear=use_linear_in_transformer,
            use_checkpoint=use_checkpoint,
            disable_self_attn=disable_self_attentions,
            disable_cross_attn=disable_cross_attentions,
        )

        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        # Class / ADM conditioning head, selected by the type of num_classes.
        if self.num_classes is not None:
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
            elif self.num_classes == "continuous":
                print("setting up linear c_adm embedding layer")
                self.label_emb = nn.Linear(1, time_embed_dim)
            elif self.num_classes == "sequential":
                assert adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
                        linear(adm_in_channels, time_embed_dim),
                        nn.SiLU(),
                        linear(time_embed_dim, time_embed_dim),
                    )
                )
            else:
                raise ValueError()

        # ---- encoder ----
        self.input_blocks = nn.ModuleList([
            TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))
        ])
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1  # current downsampling factor
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [resblock(ch, out_channels=mult * model_channels)]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    # num_head_channels takes precedence over num_heads.
                    if num_head_channels > -1:
                        current_num_heads = ch // num_head_channels
                        current_head_dim = num_head_channels
                    else:
                        current_num_heads = num_heads
                        current_head_dim = ch // num_heads

                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                        layers.append(
                            SelfTransformerBlock(ch, current_head_dim)
                            if not use_spatial_transformer
                            else transformer(
                                ch, current_num_heads, current_head_dim, depth=transformer_depth[level]
                            )
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                # Downsample between levels (except after the last one).
                out_ch = ch
                self.input_blocks.append(TimestepEmbedSequential(
                    resblock(ch, out_channels=out_ch, down=True) if resblock_updown
                    else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                ))
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2

        # ---- middle ----
        if num_head_channels > -1:
            current_num_heads = ch // num_head_channels
            current_head_dim = num_head_channels
        else:
            current_num_heads = num_heads
            current_head_dim = ch // num_heads
        self.middle_block = TimestepEmbedSequential(
            resblock(ch),
            SelfTransformerBlock(ch, current_head_dim) if not use_spatial_transformer
            else transformer(ch, current_num_heads, current_head_dim, depth=transformer_depth_middle),
            resblock(ch),
        )

        # ---- decoder + optional auxiliary per-skip modules ----
        self.output_blocks = nn.ModuleList([])
        self.map_modules = nn.ModuleList([])
        self.warp_modules = nn.ModuleList([])
        self.style_modules = nn.ModuleList([])

        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(self.num_res_blocks[level] + 1):
                ich = input_block_chans.pop()  # matching encoder skip width
                layers = [resblock(ch + ich, out_channels=model_channels * mult)]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    if num_head_channels > -1:
                        current_num_heads = ch // num_head_channels
                        current_head_dim = num_head_channels
                    else:
                        current_num_heads = num_heads
                        current_head_dim = ch // num_heads

                    if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
                        layers.append(
                            SelfTransformerBlock(ch, current_head_dim) if not use_spatial_transformer
                            else transformer(
                                ch, current_num_heads, current_head_dim, depth=transformer_depth[level]
                            )
                        )
                if level and i == self.num_res_blocks[level]:
                    # Upsample at the end of each level except the top one.
                    out_ch = ch
                    layers.append(
                        resblock(ch, up=True) if resblock_updown else Upsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))

                if map_module:
                    self.map_modules.append(
                        SelfTransformerBlock(ich)
                    )

                if warp_module:
                    self.warp_modules.append(
                        SelfTransformerBlock(ich)
                    )

                if style_modulation:
                    # LayerNorm + time-conditioned projection producing a
                    # (scale, shift) pair — hence the ch*2 widths.
                    self.style_modules.append(nn.ModuleList([
                        nn.LayerNorm(ch*2),
                        nn.Linear(time_embed_dim, ch*2),
                        zero_module(nn.Linear(ch*2, ch*2))
                    ]))

        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
        )

    def forward(self, x, timesteps=None, y=None, *args, **kwargs):
        """Embed timesteps (and optional class labels), run the trunk, project out."""
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
        emb = self.time_embed(t_emb)
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y.to(self.dtype))

        h = self._forward(x, emb, *args, **kwargs)
        return self.out(h).to(x.dtype)

    def _forward(self, x, emb, control=None, context=None, mask=None, **additional_context):
        # NOTE(review): `control` is unused here; subclasses consume it.
        hs = []
        h = x.to(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, context, mask, **additional_context)
            hs.append(h)

        h = self.middle_block(h, emb, context, mask, **additional_context)

        for module in self.output_blocks:
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context, mask, **additional_context)
        return h
316
+
317
+
318
class DualCondUNet(UNetModel):
    """UNet that adds external `control` hint residuals after a fixed set of
    encoder blocks (indices match a ControlNet-style hint encoder)."""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Input-block indices after which one hint tensor is consumed.
        self.hint_encoder_index = [0, 3, 6, 9, 11]

    def _forward(self, x, emb, control=None, context=None, mask=None, **additional_context):
        h = x.to(self.dtype)
        hs = []

        # `control` must yield exactly len(hint_encoder_index) tensors, in order.
        control_iter = iter(control)
        for idx, module in enumerate(self.input_blocks):
            h = module(h, emb, context, mask, **additional_context)

            if idx in self.hint_encoder_index:
                h += next(control_iter)
            hs.append(h)

        h = self.middle_block(h, emb, context, mask, **additional_context)

        for idx, module in enumerate(self.output_blocks):
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context, mask, **additional_context)

        return h
342
+
343
class OldUnet(UNetModel):
    """UNet with a built-in semantic hint encoder (ControlNet-style).

    ``control[0]`` is the semantic condition image; it is encoded by
    ``semantic_input_blocks`` into per-resolution hints that are added to the
    encoder activations at ``hint_encoder_index``.
    """
    def __init__(self, c_channels, model_channels, channel_mult, *args, **kwargs):
        super().__init__(channel_mult=channel_mult, model_channels=model_channels, *args, **kwargs)
        """
        Semantic condition input blocks, implementation from ControlNet.
        Paper: Adding Conditional Control to Text-to-Image Diffusion Models
        Authors: Lvmin Zhang, Anyi Rao, and Maneesh Agrawala
        Code link: https://github.com/lllyasviel/ControlNet
        """
        from refnet.modules.encoder import SimpleEncoder, MultiEncoder
        # self.semantic_input_blocks = SimpleEncoder(c_channels, model_channels)
        self.semantic_input_blocks = MultiEncoder(c_channels, model_channels, channel_mult)
        self.hint_encoder_index = [0, 3, 6, 9, 11]

    def forward(self, x, timesteps=None, control=None, context=None, y=None, **kwargs):
        # First (and only) control entry is the semantic condition image.
        concat = control[0].to(self.dtype)
        context = context.to(self.dtype)

        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb).to(self.dtype)

        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

        h = x.to(self.dtype)
        # Encode the semantic condition into one hint per indexed block.
        hints = self.semantic_input_blocks(concat, emb, context)

        for idx, module in enumerate(self.input_blocks):
            h = module(h, emb, context)
            if idx in self.hint_encoder_index:
                h += hints.pop(0)

            hs.append(h)

        h = self.middle_block(h, emb, context)

        for module in self.output_blocks:
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
        h = h.to(x.dtype)
        return self.out(h)
389
+
390
+
391
class UNetEncoder(nn.Module):
    """Encoder-only UNet that emits one zero-initialized "hint" per input
    block, in reverse (deepest-first) order.

    With ``style_modulation`` the hint is a pooled (b, 2*ch) style vector from
    a zero Linear; otherwise it is a flattened (b, h*w, ch) feature map from a
    zero 1x1 conv.
    """
    transformers = {
        "vanilla": SpatialTransformer,
    }

    def __init__(
        self,
        in_channels,
        model_channels,
        num_res_blocks,
        attention_resolutions,
        dropout = 0,
        channel_mult = (1, 2, 4, 8),
        conv_resample = True,
        dims = 2,
        num_classes = None,
        use_checkpoint = False,
        num_heads = -1,
        num_head_channels = -1,
        use_scale_shift_norm = False,
        resblock_updown = False,
        use_spatial_transformer = False,    # custom transformer support
        transformer_depth = 1,              # custom transformer support
        context_dim = None,                 # custom transformer support
        disable_self_attentions = None,
        disable_cross_attentions = None,
        num_attention_blocks = None,
        use_linear_in_transformer = False,
        adm_in_channels = None,
        transformer_type = "vanilla",
        style_modulation = False,
    ):
        super().__init__()
        if use_spatial_transformer:
            assert exists(
                context_dim) or disable_cross_attentions, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
            assert transformer_type in self.transformers.keys(), f'Assigned transformer is not implemented.. Choices: {self.transformers.keys()}'
            from omegaconf.listconfig import ListConfig
            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
        self.in_channels = in_channels
        self.model_channels = model_channels
        # Normalize num_res_blocks into a per-level list.
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                 "as a list/tuple (per-level) with the same length as channel_mult")
            self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(
                map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                  f"attention will still not be set.")

        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = torch.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.style_modulation = style_modulation

        if isinstance(transformer_depth, int):
            transformer_depth = len(channel_mult) * [transformer_depth]

        time_embed_dim = model_channels * 4

        # Factories with the shared constructor arguments pre-bound.
        resblock = partial(
            ResBlock,
            emb_channels=time_embed_dim,
            dropout=dropout,
            dims=dims,
            use_checkpoint=use_checkpoint,
            use_scale_shift_norm=use_scale_shift_norm,
        )

        transformer = partial(
            self.transformers[transformer_type],
            context_dim=context_dim,
            use_linear=use_linear_in_transformer,
            use_checkpoint=use_checkpoint,
            disable_self_attn=disable_self_attentions,
            disable_cross_attn=disable_cross_attentions,
        )

        zero_conv = partial(nn.Conv2d, kernel_size=1, stride=1, padding=0)

        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        # Class / ADM conditioning head, selected by the type of num_classes.
        if self.num_classes is not None:
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
            elif self.num_classes == "continuous":
                print("setting up linear c_adm embedding layer")
                self.label_emb = nn.Linear(1, time_embed_dim)
            elif self.num_classes == "sequential":
                assert adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
                        linear(adm_in_channels, time_embed_dim),
                        nn.SiLU(),
                        linear(time_embed_dim, time_embed_dim),
                    )
                )
            else:
                raise ValueError()

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        # One zero-initialized output head per input block.
        self.zero_layers = nn.ModuleList([zero_module(
            nn.Linear(model_channels, model_channels * 2) if style_modulation else
            zero_conv(model_channels, model_channels)
        )])

        ch = model_channels
        ds = 1  # current downsampling factor
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [resblock(ch, out_channels=mult * model_channels)]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    # NOTE(review): this shadows the `num_heads` argument;
                    # num_head_channels must be set when attention is used here.
                    num_heads = ch // num_head_channels

                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                        layers.append(
                            SelfTransformerBlock(ch, num_head_channels)
                            if not use_spatial_transformer
                            else transformer(
                                ch, num_heads, num_head_channels, depth=transformer_depth[level]
                            )
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self.zero_layers.append(zero_module(
                    nn.Linear(ch, ch * 2) if style_modulation else zero_conv(ch, ch)
                ))

            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(TimestepEmbedSequential(
                    resblock(ch, out_channels=mult * model_channels, down=True) if resblock_updown else Downsample(
                        ch, conv_resample, dims=dims, out_channels=out_ch
                    )
                ))
                self.zero_layers.append(zero_module(
                    nn.Linear(out_ch, min(model_channels * 8, out_ch * 4)) if style_modulation else
                    zero_conv(out_ch, out_ch)
                ))
                ch = out_ch
                ds *= 2

    def forward(self, x, timesteps = None, y = None, *args, **kwargs):
        """Return the list of hints produced by `_forward` (deepest first)."""
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
        emb = self.time_embed(t_emb)

        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y.to(self.dtype))

        hs = self._forward(x, emb, *args, **kwargs)
        return hs

    def _forward(self, x, emb, context = None, **additional_context):
        hints = []
        h = x.to(self.dtype)

        for zero_layer, module in zip(self.zero_layers, self.input_blocks):
            h = module(h, emb, context, **additional_context)

            if self.style_modulation:
                # Global-average-pooled style vector, shape (b, 2*ch).
                hint = zero_layer(h.mean(dim=[2, 3]))
            else:
                # Spatial hint flattened to token sequence, shape (b, h*w, ch).
                hint = zero_layer(h)
                hint = rearrange(hint, "b c h w -> b (h w) c").contiguous()
            hints.append(hint)

        # Consumers expect deepest features first.
        hints.reverse()
        return hints
refnet/sampling/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .denoiser import CFGDenoiser, DiffuserDenoiser
2
+ from .hook import UnetHook, torch_dfs
3
+ from .tps_transformation import tps_warp
4
+ from .sampler import KDiffusionSampler, kdiffusion_sampler_list
5
+ from .scheduler import get_noise_schedulers
6
+
7
def get_sampler_list():
    """Return every available sampler name, alphabetically sorted.

    Combines the diffusers-backed schedulers (exposed with a ``diffuser_``
    prefix) with the k-diffusion samplers.
    """
    diffuser_names = [f"diffuser_{name}" for name in DiffuserDenoiser.scheduler_types]
    return sorted(diffuser_names + kdiffusion_sampler_list())
refnet/sampling/denoiser.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ import inspect
5
+ import os.path as osp
6
+ from typing import Union, Optional
7
+ from tqdm import tqdm
8
+ from omegaconf import OmegaConf
9
+ from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
10
+ from diffusers.schedulers import (
11
+ DDIMScheduler,
12
+ DPMSolverMultistepScheduler,
13
+ PNDMScheduler,
14
+ LMSDiscreteScheduler,
15
+ )
16
+
17
def exists(v):
    """Return True iff *v* is not None (falsy values like 0/"" still exist)."""
    return v is not None
19
+
20
+
21
+
22
class CFGDenoiser(nn.Module):
    """
    Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
    that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
    instead of one. Originally, the second prompt is just an empty string, but we use non-empty
    negative prompt.
    """

    def __init__(self, model, device):
        super().__init__()
        # Pick the k-diffusion wrapper matching the model's parameterization.
        denoiser = CompVisDenoiser if model.parameterization == "eps" else CompVisVDenoiser
        self.model_wrap = denoiser(model, device=device)

    @property
    def inner_model(self):
        return self.model_wrap

    def forward(
        self,
        x,
        sigma,
        cond: dict,
        cond_scale: Union[float, list[float]]
    ):
        """
        Simplified k-diffusion denoiser for sketch colorization.
        Available for reference CFG / sketch CFG or dual CFG.

        NOTE(review): `cond` is assumed to already be batched to match the
        repeat count (2 for single CFG, 3 for dual CFG) — confirm against the
        sampler that builds it.
        """
        if not isinstance(cond_scale, list):
            if cond_scale > 1.:
                repeats = 2  # cond + uncond
            else:
                # Scale <= 1: plain conditional pass, no guidance mixing.
                return self.inner_model(x, sigma, cond=cond)
        else:
            repeats = 3  # cond + two separate unconds (dual CFG)

        x_in = torch.cat([x] * repeats)
        sigma_in = torch.cat([sigma] * repeats)
        x_out = self.inner_model(x_in, sigma_in, cond=cond).chunk(repeats)

        if repeats == 2:
            x_cond, x_uncond = x_out[:]
            return x_uncond + (x_cond - x_uncond) * cond_scale
        else:
            # Average of two independent CFG results, one per uncond branch.
            x_cond, x_uncond_0, x_uncond_1 = x_out[:]
            return (x_uncond_0 + (x_cond - x_uncond_0) * cond_scale[0] +
                    x_uncond_1 + (x_cond - x_uncond_1) * cond_scale[1]) * 0.5
69
+
70
+
71
+
72
+
73
# Directory holding one YAML config per diffusers scheduler type.
scheduler_config_path = "configs/scheduler_cfgs"
class DiffuserDenoiser:
    """Sampling loop built on Hugging Face diffusers schedulers, with CFG and
    optional background-inpainting latents."""
    scheduler_types = {
        "ddim": DDIMScheduler,
        "dpm": DPMSolverMultistepScheduler,
        "dpm_sde": DPMSolverMultistepScheduler,
        "pndm": PNDMScheduler,
        "lms": LMSDiscreteScheduler
    }
    def __init__(self, scheduler_type, prediction_type, use_karras=False):
        # Accept both "ddim" and the UI-facing "diffuser_ddim" spelling.
        scheduler_type = scheduler_type.replace("diffuser_", "")
        assert scheduler_type in self.scheduler_types.keys(), "Selected scheduler is not implemented"
        scheduler = self.scheduler_types[scheduler_type]
        scheduler_config = OmegaConf.load(osp.abspath(osp.join(scheduler_config_path, scheduler_type + ".yaml")))
        # Only set karras sigmas when the scheduler class supports them.
        if "use_karras_sigmas" in set(inspect.signature(scheduler).parameters.keys()):
            scheduler_config.use_karras_sigmas = use_karras
        self.scheduler = scheduler(prediction_type=prediction_type, **scheduler_config)

    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(
            inspect.signature(self.scheduler.step).parameters.keys()
        )
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(
            inspect.signature(self.scheduler.step).parameters.keys()
        )
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def __call__(
        self,
        x,
        cond,
        cond_scale,
        unet,
        timesteps,
        generator: Optional[Union[torch.Generator, list[torch.Generator]]] = None,
        eta: float = 0.0,
        device: str = "cuda"
    ):
        """Run the full denoising loop and return the final latents.

        NOTE(review): mutates `cond` in place (pops "inpaint_bg"/"threshold").
        """
        self.scheduler.set_timesteps(timesteps, device=device)
        timesteps = self.scheduler.timesteps
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        x_start = x
        x = x * self.scheduler.init_noise_sigma
        inpaint_latents = cond.pop("inpaint_bg", None)

        if exists(inpaint_latents):
            # Binarize the mask; masked (1) regions are denoised, the rest is
            # re-noised from the inpaint background latents each step.
            mask = cond.get("mask", None)
            threshold = cond.pop("threshold", 0.5)
            inpaint_latents = inpaint_latents[0]
            assert exists(mask)
            mask = mask[0]
            mask = torch.where(mask > threshold, torch.ones_like(mask), torch.zeros_like(mask))

        for i, t in enumerate(tqdm(timesteps)):
            x_t = self.scheduler.scale_model_input(x, t)

            # Same CFG batching scheme as CFGDenoiser.forward.
            if not isinstance(cond_scale, list):
                if cond_scale > 1.:
                    repeats = 2
                else:
                    repeats = 1
            else:
                repeats = 3

            x_in = torch.cat([x_t] * repeats)
            x_out = unet.apply_model(
                x_in,
                t[None].expand(x_in.shape[0]),
                cond=cond
            )

            if repeats == 1:
                pred = x_out

            elif repeats == 2:
                x_cond, x_uncond = x_out.chunk(2)
                pred = x_uncond + (x_cond - x_uncond) * cond_scale

            else:
                # Dual CFG: average two guidance directions.
                x_cond, x_uncond_0, x_uncond_1 = x_out.chunk(3)
                pred = (x_uncond_0 + (x_cond - x_uncond_0) * cond_scale[0] +
                        x_uncond_1 + (x_cond - x_uncond_1) * cond_scale[1]) * 0.5

            x = self.scheduler.step(
                pred, t, x, **extra_step_kwargs, return_dict=False
            )[0]

            if exists(inpaint_latents) and exists(mask) and i < len(timesteps) - 1:
                # Re-noise the background latents to the next timestep and
                # composite them outside the mask (diffusers inpaint pattern).
                noise_timestep = timesteps[i + 1]
                init_latents_proper = inpaint_latents
                init_latents_proper = self.scheduler.add_noise(
                    init_latents_proper, x_start, torch.tensor([noise_timestep])
                )
                x = (1 - mask) * init_latents_proper + mask * x

        return x
refnet/sampling/hook.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from refnet.modules.transformer import BasicTransformerBlock, SelfInjectedTransformerBlock
5
+ from refnet.util import checkpoint_wrapper
6
+
7
+ """
8
+ This implementation refers to Multi-ControlNet, thanks for the authors
9
+ Paper: Adding Conditional Control to Text-to-Image Diffusion Models
10
+ Link: https://github.com/Mikubill/sd-webui-controlnet
11
+ """
12
+
13
def exists(v):
    """Return True when *v* is not None."""
    return not (v is None)
15
+
16
def torch_dfs(model: nn.Module):
    """Return *model* followed by all of its descendant modules, depth-first preorder."""
    nodes = [model]
    for sub in model.children():
        nodes.extend(torch_dfs(sub))
    return nodes
21
+
22
class AutoMachine():
    """Mode flags for the reference-attention state machine."""
    # Hooked blocks consume previously cached reference features.
    Read = "read"
    # Hooked blocks record their hidden states into the bank.
    Write = "write"
25
+
26
+
27
+ """
28
+ This class controls the attentions of reference unet and denoising unet
29
+ """
30
class ReferenceAttentionControl:
    """Couples a writer (reference) UNet with a reader (denoising) UNet.

    The forward of every ``BasicTransformerBlock`` in the writer is patched to
    cache its pre-attention hidden states in ``bank``; the paired reader blocks
    concatenate that bank into their self-attention context, injecting the
    reference features into the denoising pass.

    FIX: ``writer_modules``/``reader_modules`` used to be mutable *class*
    attributes, shared by every instance — two controllers would interleave
    their hooked modules. They are now per-instance state.
    """

    def __init__(
        self,
        reader_module,
        writer_module,
        time_embed_ch = 0,
        only_decoder = True,
        *args,
        **kwargs
    ):
        # Per-instance hook bookkeeping (must be set before register() runs).
        self.writer_modules = []
        self.reader_modules = []
        self.time_embed_ch = time_embed_ch
        self.trainable_layers = []
        self.only_decoder = only_decoder
        self.hooked = False

        self.register("read", reader_module)
        self.register("write", writer_module)

        if time_embed_ch > 0:
            self.insert_time_emb_proj(reader_module)

    def insert_time_emb_proj(self, unet):
        """Attach a trainable time-embedding projection to every transformer
        block (decoder-only when ``only_decoder``) and record it for training."""
        for module in torch_dfs(unet.output_blocks if self.only_decoder else unet):
            if isinstance(module, BasicTransformerBlock):
                module.time_proj = nn.Linear(self.time_embed_ch, module.dim)
                self.trainable_layers.append(module.time_proj)

    def register(self, mode, unet):
        """Patch the transformer blocks of *unet* as writers or readers."""
        @checkpoint_wrapper
        def transformer_forward_write(self, x, context=None, mask=None, emb=None, **kwargs):
            # Standard block forward, but stash the normalized input in
            # ``bank`` so a paired reader block can attend to it later.
            x_in = self.norm1(x)
            x = self.attn1(x_in) + x

            if not self.disable_cross_attn:
                x = self.attn2(self.norm2(x), context) + x
            x = self.ff(self.norm3(x)) + x

            self.bank = x_in
            return x

        @checkpoint_wrapper
        def transformer_forward_read(self, x, context=None, mask=None, emb=None, **kwargs):
            if exists(self.bank):
                bank = self.bank
                # Broadcast cached reference features over the batch
                # (e.g. when CFG duplicates the latent batch).
                if bank.shape[0] != x.shape[0]:
                    bank = bank.repeat(x.shape[0], 1, 1)
                if hasattr(self, "time_proj"):
                    bank = bank + self.time_proj(emb).unsqueeze(1)
                x_in = self.norm1(x)

                # Self-attention over [own tokens, reference tokens].
                x = self.attn1(
                    x = x_in,
                    context = torch.cat([x_in, bank], 1),
                    mask = mask,
                    scale_factor = self.scale_factor,
                    **kwargs
                ) + x

                x = self.attn2(
                    x = self.norm2(x),
                    context = context,
                    mask = mask,
                    scale = self.reference_scale,
                    scale_factor = self.scale_factor
                ) + x

                x = self.ff(self.norm3(x)) + x
            else:
                # No reference cached yet: fall back to the unpatched forward.
                x = self.original_forward(x, context, mask, emb)
            return x

        assert mode in ["write", "read"]

        if mode == "read":
            self.hooked = True
        for module in torch_dfs(unet.output_blocks if self.only_decoder else unet):
            if isinstance(module, BasicTransformerBlock):
                if mode == "write":
                    module.original_forward = module.forward
                    module.forward = transformer_forward_write.__get__(module, BasicTransformerBlock)
                    self.writer_modules.append(module)
                else:
                    # NOTE(review): SelfInjectedTransformerBlock instances are
                    # assumed to manage their own injection and are only logged
                    # here — confirm reader/writer lists stay index-aligned.
                    if not isinstance(module, SelfInjectedTransformerBlock):
                        print(f"Hooking transformer block {module.__class__.__name__} for read mode")
                    module.original_forward = module.forward
                    module.forward = transformer_forward_read.__get__(module, BasicTransformerBlock)
                    self.reader_modules.append(module)

    def update(self):
        """Copy each writer block's cached bank into its paired reader block."""
        for idx in range(len(self.writer_modules)):
            self.reader_modules[idx].bank = self.writer_modules[idx].bank

    def restore(self):
        """Unpatch both UNets and drop all cached banks."""
        for idx in range(len(self.writer_modules)):
            self.writer_modules[idx].forward = self.writer_modules[idx].original_forward
            self.reader_modules[idx].forward = self.reader_modules[idx].original_forward
            self.reader_modules[idx].bank = None
        self.hooked = False

    def clean(self):
        """Drop cached banks on both sides without unpatching."""
        for idx in range(len(self.reader_modules)):
            self.reader_modules[idx].bank = None
        for idx in range(len(self.writer_modules)):
            self.writer_modules[idx].bank = None
        self.hooked = False

    def reader_restore(self):
        """Unpatch only the reader UNet and drop its cached banks."""
        for idx in range(len(self.reader_modules)):
            self.reader_modules[idx].forward = self.reader_modules[idx].original_forward
            self.reader_modules[idx].bank = None
        self.hooked = False

    def get_trainable_layers(self):
        """Return the time-embedding projections added by insert_time_emb_proj."""
        return self.trainable_layers
146
+
147
+
148
+ """
149
+ This class is for self-injection inside the denoising unet
150
+ """
151
class UnetHook:
    """Reference-only self-injection inside the denoising UNet.

    Patches the UNet forward so every denoising step first runs a noised copy
    of the reference latents through the same network in Write mode (caching
    self-attention contexts), then runs the actual latents in Read mode with
    those contexts concatenated into self-attention.
    """
    def __init__(self):
        super().__init__()
        # Current state of the write/read attention state machine.
        self.attention_auto_machine = AutoMachine.Read

    def enhance_reference(
        self,
        model,
        ldm,
        bs,
        s,
        r,
        style_cfg=0.5,
        control_cfg=0,
        gr_indice=None,
        injection=False,
        start_step=0,
    ):
        """Install the reference hooks.

        model: wrapper whose ``diffusion_model`` is the UNet to patch.
        ldm: latent-diffusion model providing add_noise / control_encoder.
        bs: batch size used to repeat the encoded control features.
        s: control (e.g. sketch) input fed through ldm.control_encoder.
        r: reference latents re-noised at every step.
        style_cfg: blend between style-faithful and unconditional attention.
        gr_indice: batch indices treated as unconditional for style CFG.
        start_step: fraction of the schedule after which injection starts.
        """
        def forward(self, x, t, control, context, **kwargs):
            # Inject only after start_step fraction of the schedule has passed.
            if 1 - t[0] / (ldm.num_timesteps - 1) >= outer.start_step:
                # Write pass: run the re-noised reference through the UNet so
                # each transformer block caches its attention context.
                outer.attention_auto_machine = AutoMachine.Write

                rx = ldm.add_noise(outer.r.cpu(), torch.round(t.float()).long().cpu()).cuda().to(x.dtype)
                self.original_forward(rx, t, control=outer.s, context=context, **kwargs)

            # Read pass: the real latents consume the cached contexts.
            outer.attention_auto_machine = AutoMachine.Read
            return self.original_forward(x, t, control=control, context=context, **kwargs)

        def hacked_basic_transformer_inner_forward(self, x, context=None, mask=None, emb=None, **kwargs):
            x_norm1 = self.norm1(x)
            self_attn1 = None
            if self.disable_self_attn:
                # Do not use self-attention
                self_attn1 = self.attn1(x_norm1, context=context, **kwargs)

            else:
                # Use self-attention
                self_attention_context = x_norm1
                if outer.attention_auto_machine == AutoMachine.Write:
                    # Cache this block's context for the subsequent Read pass.
                    self.bank.append(self_attention_context.detach().clone())
                    self.style_cfgs.append(outer.current_style_fidelity)
                if outer.attention_auto_machine == AutoMachine.Read:
                    if len(self.bank) > 0:
                        style_cfg = sum(self.style_cfgs) / float(len(self.style_cfgs))
                        # Unconditional branch attends to [own, reference] tokens.
                        self_attn1_uc = self.attn1(
                            x_norm1,
                            context=torch.cat([self_attention_context] + self.bank, dim=1),
                            **kwargs
                        )
                        self_attn1_c = self_attn1_uc.clone()
                        if len(outer.current_uc_indices) > 0 and style_cfg > 1e-5:
                            # Style-faithful branch: chosen batch rows ignore the bank.
                            self_attn1_c[outer.current_uc_indices] = self.attn1(
                                x_norm1[outer.current_uc_indices],
                                context=self_attention_context[outer.current_uc_indices],
                                **kwargs
                            )
                        self_attn1 = style_cfg * self_attn1_c + (1.0 - style_cfg) * self_attn1_uc
                    # Bank is single-use per step.
                    self.bank = []
                    self.style_cfgs = []
                if self_attn1 is None:
                    self_attn1 = self.attn1(x_norm1, context=self_attention_context)

            x = self_attn1.to(x.dtype) + x
            x = self.attn2(self.norm2(x), context, mask, self.reference_scale, self.scale_factor, **kwargs) + x
            x = self.ff(self.norm3(x)) + x
            return x

        # NOTE: the comprehension rebinds `s` to each encoded control feature.
        self.s = [s.repeat(bs, 1, 1, 1) * control_cfg for s in ldm.control_encoder(s)]
        self.r = r
        self.injection = injection
        self.start_step = start_step
        self.current_uc_indices = gr_indice
        self.current_style_fidelity = style_cfg

        outer = self
        model = model.diffusion_model
        model.original_forward = model.forward
        # TODO: change the class name to target
        model.forward = forward.__get__(model, model.__class__)
        all_modules = torch_dfs(model)

        for module in all_modules:
            if isinstance(module, BasicTransformerBlock):
                module._unet_hook_original_forward = module.forward
                module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
                module.bank = []
                module.style_cfgs = []


    def restore(self, model):
        """Remove all hooks installed by enhance_reference and drop caches."""
        model = model.diffusion_model
        if hasattr(model, "original_forward"):
            model.forward = model.original_forward
            del model.original_forward

        all_modules = torch_dfs(model)
        for module in all_modules:
            if isinstance(module, BasicTransformerBlock):
                if hasattr(module, "_unet_hook_original_forward"):
                    module.forward = module._unet_hook_original_forward
                    del module._unet_hook_original_forward
                if hasattr(module, "bank"):
                    module.bank = None
                if hasattr(module, "style_cfgs"):
                    del module.style_cfgs
refnet/sampling/manipulation.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ import numpy as np
5
+
6
+
7
def compute_pwv(s: torch.Tensor, dscale: torch.Tensor, ratio=2, thresholds=(0.5, 0.55, 0.65, 0.95)):
    """Compute piece-wise manipulation weights from per-token scales.

    The scales are min-max normalized per batch item, then mapped through a
    four-segment piece-wise linear curve: tokens below thresholds[0] are pushed
    away (-dscale * ratio), tokens above thresholds[3] receive the full dscale,
    and the segments in between interpolate.

    s: per-token scales of shape (b, n, 1).
    dscale: target edit magnitude (broadcastable against s).
    Returns a tensor of the same shape as *s*.

    FIX: the original ``assert len(s.shape) == 3, len(thresholds) == 4`` used
    the second condition as the assert *message*, so it was never checked.
    """
    assert len(s.shape) == 3, "expected scales of shape (b, n, 1)"
    assert len(thresholds) == 4, "expected exactly four thresholds"
    maxm = s.max(dim=1, keepdim=True).values
    minm = s.min(dim=1, keepdim=True).values
    d = maxm - minm

    # Min-max normalize each item's scales to [0, 1].
    maxmin = (s - minm) / d

    # Segment 1: strongly suppress tokens below thresholds[0]; otherwise ramp
    # from -dscale up across [thresholds[0], thresholds[1]].
    adjust_scale = torch.where(maxmin <= thresholds[0],
                               -dscale * ratio,
                               -dscale + dscale * (maxmin - thresholds[0]) / (thresholds[1] - thresholds[0]))
    # Segment 2: ramp from 0 to 0.5*dscale across (thresholds[1], thresholds[2]].
    adjust_scale = torch.where(maxmin > thresholds[1],
                               0.5 * dscale * (maxmin - thresholds[1]) / (thresholds[2] - thresholds[1]),
                               adjust_scale)
    # Segment 3: ramp from 0.5*dscale to dscale across (thresholds[2], thresholds[3]].
    adjust_scale = torch.where(maxmin > thresholds[2],
                               0.5 * dscale + 0.5 * dscale * (maxmin - thresholds[2]) / (thresholds[3] - thresholds[2]),
                               adjust_scale)
    # Segment 4: full magnitude above thresholds[3].
    adjust_scale = torch.where(maxmin > thresholds[3], dscale, adjust_scale)
    return adjust_scale
29
+
30
+
31
def local_manipulate_step(clip, v, t, target_scale, a=None, c=None, enhance=False, thresholds=[]):
    """Apply one local (token-wise) manipulation to CLIP visual tokens.

    v: visual tokens of shape (b, 1+n, c); v[:, 0] is the CLS token, which is
       split off, left untouched, and re-attached at the end.
    t: target text embedding; target_scale: desired alignment with *t*.
    c: control text embedding, or the literal string "everything" to edit all
       tokens uniformly.
    a: optional anchor text (raw string) encoded here for directional edits.
    Returns tokens of the same shape as *v*.
    """
    # print(f"target:{t}, anchor:{a}")
    cls_token = v[:, 0].unsqueeze(1)
    v = v[:, 1:]

    # Current global alignment between the image (CLS token) and the target.
    cur_target_scale = clip.calculate_scale(cls_token, t)
    # control_scale = clip.calculate_scale(cls_token, c)
    # print(f"current global target scale: {cur_target_scale},",
    #       f" global control scale: {control_scale}")

    if a is not None and a != "none":
        a = [a] * v.shape[0]
        a = clip.encode_text(a)
        anchor_scale = clip.calculate_scale(cls_token, a)
        # Edit magnitude: gap between the desired scale and the current (or,
        # when enhancing, the anchor's) alignment.
        dscale = target_scale - cur_target_scale if not enhance else target_scale - anchor_scale
        # print(f"global anchor scale: {anchor_scale}")

        c_map = clip.calculate_scale(v, c)
        a_map = clip.calculate_scale(v, a)
        # Piece-wise weights localize the edit to tokens aligned with the
        # control concept; "everything" applies dscale uniformly.
        # NOTE(review): *c* may be an encoded tensor when called from
        # local_manipulate, so the "everything" comparison only triggers for
        # raw strings — confirm intended.
        pwm = compute_pwv(c_map, dscale, thresholds=thresholds) if c != "everything" else dscale
        base = 1 if enhance else 0
        # Move tokens along the anchor -> target text direction.
        v = v + (pwm + base * a_map) * (t - a)
    else:
        dscale = target_scale - cur_target_scale
        c_map = clip.calculate_scale(v, c)
        pwm = compute_pwv(c_map, dscale, thresholds=thresholds) if c != "everything" else dscale
        v = v + pwm * t
    v = torch.cat([cls_token, v], dim=1)
    return v
60
+
61
def local_manipulate(clip, v, targets, target_scales, anchors, controls, enhances=[], thresholds_list=[]):
    """
    v: visual tokens in shape (b, n, c)
    target: target text embeddings in shape (b, 1 ,c)
    control: control text embeddings in shape (b, 1, c)

    Encodes controls and targets in a single text batch, then applies each
    manipulation step in sequence, threading the tokens through.
    """
    encoded = clip.encode_text(controls + targets)
    control_embs, target_embs = encoded.chunk(2)
    steps = zip(target_embs, anchors, control_embs, target_scales, enhances, thresholds_list)
    for target, anchor, control, scale, enhance, thresholds in steps:
        v = local_manipulate_step(clip, v, target, scale, anchor, control, enhance, thresholds)
    return v
71
+
72
+
73
def global_manipulate_step(clip, v, t, target_scale, a=None, enhance=False):
    """Apply one global manipulation to the visual embedding *v*.

    With an anchor *a* (a raw string, ignored when None or "none"):
    enhance removes the anchor component, otherwise v moves along the
    anchor -> target direction by target_scale. Without an anchor:
    enhance adds target_scale * t directly, otherwise v is shifted so its
    alignment with *t* reaches target_scale.
    """
    if a is not None and a != "none":
        anchor = clip.encode_text([a] * v.shape[0])
        if enhance:
            # Subtract the anchor-aligned component of v.
            return v - clip.calculate_scale(v, anchor) * anchor
        return v + target_scale * (t - anchor)
    if enhance:
        return v + target_scale * t
    current = clip.calculate_scale(v, t)
    return v + (target_scale - current) * t
89
+
90
+
91
def global_manipulate(clip, v, targets, target_scales, anchors, enhances):
    """Encode the target texts and apply every global step in sequence."""
    target_embs = clip.encode_text(targets)
    for target, anchor, scale, enhance in zip(target_embs, anchors, target_scales, enhances):
        v = global_manipulate_step(clip, v, target, scale, anchor, enhance)
    return v
96
+
97
+
98
def assign_heatmap(s: torch.Tensor, threshold: float):
    """
    The shape of input scales tensor should be (b, n, 1)

    Min-max normalizes the scales per batch item and binarizes them:
    0.0 below *threshold*, 0.25 at or above it.
    """
    lo = s.min(dim=1, keepdim=True).values
    hi = s.max(dim=1, keepdim=True).values
    normalized = (s - lo) / (hi - lo)
    return torch.where(normalized < threshold, torch.zeros_like(s), torch.full_like(s, 0.25))
106
+
107
+
108
def get_heatmaps(model, reference, height, width, vis_c, ts0, ts1, ts2, ts3,
                 controls, targets, anchors, thresholds_list, target_scales, enhances):
    """Visualize where the text concept *vis_c* activates on the reference.

    Encodes the reference image with CLIP, optionally applies the requested
    local manipulations first, then scores every patch token against *vis_c*
    and thresholds the upsampled score map at four levels (ts0..ts3).

    Returns a list of four uint8 numpy heatmaps of shape (height, width, 1).
    """
    model.low_vram_shift("cond")
    clip = model.cond_stage_model

    v = clip.encode(reference, "full")
    if len(targets) > 0:
        # Controls and targets are encoded in a single text batch, then split.
        controls, targets = clip.encode_text(controls + targets).chunk(2)
        inputs_iter = zip(controls, targets, anchors, target_scales, thresholds_list, enhances)
        for c, t, a, target_scale, thresholds, enhance in inputs_iter:
            # update image tokens
            v = local_manipulate_step(clip, v, t, target_scale, a, c, enhance, thresholds)
    token_length = v.shape[1] - 1
    # Assumes a square ViT patch grid (n = grid_num**2) — TODO confirm.
    grid_num = int(token_length ** 0.5)
    vis_c = clip.encode_text([vis_c])
    local_v = v[:, 1:]  # drop the CLS token; score patch tokens only
    scale = clip.calculate_scale(local_v, vis_c)
    scale = scale.permute(0, 2, 1).view(1, 1, grid_num, grid_num)
    # Upsample the patch-level scores to the output resolution.
    scale = F.interpolate(scale, size=(height, width), mode="bicubic").squeeze(0).view(1, height * width)

    # calculate heatmaps
    heatmaps = []
    for threshold in [ts0, ts1, ts2, ts3]:
        heatmap = assign_heatmap(scale, threshold=threshold)
        heatmap = heatmap.view(1, height, width).permute(1, 2, 0).cpu().numpy()
        heatmap = (heatmap * 255.).astype(np.uint8)
        heatmaps.append(heatmap)
    return heatmaps
refnet/sampling/sampler.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import torch
3
+ import k_diffusion
4
+ import inspect
5
+
6
+ from types import SimpleNamespace
7
+ from refnet.util import default
8
+ from .scheduler import schedulers, schedulers_map
9
+ from .denoiser import CFGDenoiser
10
+
11
+ defaults = SimpleNamespace(**{
12
+ "eta_ddim": 0.0,
13
+ "eta_ancestral": 1.0,
14
+ "ddim_discretize": "uniform",
15
+ "s_churn": 0.0,
16
+ "s_tmin": 0.0,
17
+ "s_noise": 1.0,
18
+ "k_sched_type": "Automatic",
19
+ "sigma_min": 0.0,
20
+ "sigma_max": 0.0,
21
+ "rho": 0.0,
22
+ "eta_noise_seed_delta": 0,
23
+ "always_discard_next_to_last_sigma": False,
24
+ })
25
+
26
@dataclasses.dataclass
class Sampler:
    """Static description of one k-diffusion sampler entry."""
    # Display label; also the lookup key in k_diffusion_samplers_map.
    label: str
    # Name of the sampling function in k_diffusion.sampling (or a callable).
    funcname: str
    # Alternative identifiers (webui-style aliases).
    aliases: any
    # Per-sampler options: 'scheduler', 'brownian_noise', 'solver_type', ...
    options: dict
32
+
33
+
34
# Sampler registry: (label, k_diffusion function name, aliases, options).
# The options are consumed by KDiffusionSampler.get_sigmas / __call__.
samplers_k_diffusion = [
    Sampler('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {'scheduler': 'karras'}),
    Sampler('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
    Sampler('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde'], {'scheduler': 'exponential', "brownian_noise": True}),
    Sampler('DPM++ 2M SDE Heun', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun'], {'scheduler': 'exponential', "brownian_noise": True, "solver_type": "heun"}),
    Sampler('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
    Sampler('DPM++ 3M SDE', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde'], {'scheduler': 'exponential', 'discard_next_to_last_sigma': True, "brownian_noise": True}),
    Sampler('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}),
    Sampler('Euler', 'sample_euler', ['k_euler'], {}),
    Sampler('LMS', 'sample_lms', ['k_lms'], {}),
    Sampler('Heun', 'sample_heun', ['k_heun'], {"second_order": True}),
    Sampler('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "second_order": True}),
    Sampler('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
    Sampler('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
    Sampler('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True})
]
50
+
51
# Extra keyword arguments each sampling function may accept; only those
# actually present in the function's signature are forwarded (see
# KDiffusionSampler.initialize).
sampler_extra_params = {
    'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
    'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
    'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
    'sample_dpm_fast': ['s_noise'],
    'sample_dpm_2_ancestral': ['s_noise'],
    'sample_dpmpp_2s_ancestral': ['s_noise'],
    'sample_dpmpp_sde': ['s_noise'],
    'sample_dpmpp_2m_sde': ['s_noise'],
    'sample_dpmpp_3m_sde': ['s_noise'],
}
62
+
63
def kdiffusion_sampler_list():
    """Return the display labels of all registered k-diffusion samplers."""
    return [sampler.label for sampler in samplers_k_diffusion]
65
+
66
+
67
# Label -> Sampler lookup for the registry above.
k_diffusion_samplers_map = {x.label: x for x in samplers_k_diffusion}
# Scheduler internal name -> sigma-schedule function.
k_diffusion_scheduler = {x.name: x.function for x in schedulers}
69
+
70
def exists(v):
    """Return True when *v* is not None."""
    return not (v is None)
72
+
73
+
74
class KDiffusionSampler:
    """Runs a k-diffusion sampling function over a CFG-wrapped diffusion
    model, following the webui-style sampler/scheduler configuration."""

    def __init__(self, sampler, scheduler, sd, device):
        # k_diffusion_samplers_map[]
        # Resolve the sampler configuration by its display label.
        self.config = k_diffusion_samplers_map[sampler]
        funcname = self.config.funcname

        # funcname may already be a callable; otherwise resolve it in k_diffusion.sampling.
        self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, funcname)
        self.scheduler_name = scheduler
        self.sd = CFGDenoiser(sd, device)
        self.model_wrap = self.sd.model_wrap
        self.device = device

        # Karras-style stochasticity knobs (overridden from `defaults`).
        self.s_min_uncond = None
        self.s_churn = 0.0
        self.s_tmin = 0.0
        self.s_tmax = float('inf')
        self.s_noise = 1.0

        self.eta_option_field = 'eta_ancestral'
        self.eta_infotext_field = 'Eta'
        self.eta_default = 1.0
        self.eta = None

        # Names of extra kwargs this sampler may accept (see sampler_extra_params).
        self.extra_params = []

        # Clamp the wrapped denoiser's sigma range to the model's own limits.
        if exists(sd.sigma_max) and exists(sd.sigma_min):
            self.model_wrap.sigmas[-1] = sd.sigma_max
            self.model_wrap.sigmas[0] = sd.sigma_min

    def initialize(self):
        """Build the kwargs to forward to the sampling function, keeping only
        parameters its signature actually accepts."""
        self.eta = getattr(defaults, self.eta_option_field, 0.0)

        extra_params_kwargs = {}
        for param_name in self.extra_params:
            if param_name in inspect.signature(self.func).parameters:
                extra_params_kwargs[param_name] = getattr(self, param_name)

        if 'eta' in inspect.signature(self.func).parameters:
            extra_params_kwargs['eta'] = self.eta

        if len(self.extra_params) > 0:
            s_churn = getattr(defaults, 's_churn', self.s_churn)
            s_tmin = getattr(defaults, 's_tmin', self.s_tmin)
            s_tmax = getattr(defaults, 's_tmax', self.s_tmax) or self.s_tmax  # 0 = inf
            s_noise = getattr(defaults, 's_noise', self.s_noise)

            # Only override values that differ from the current settings.
            if 's_churn' in extra_params_kwargs and s_churn != self.s_churn:
                extra_params_kwargs['s_churn'] = s_churn
                self.s_churn = s_churn
            if 's_tmin' in extra_params_kwargs and s_tmin != self.s_tmin:
                extra_params_kwargs['s_tmin'] = s_tmin
                self.s_tmin = s_tmin
            if 's_tmax' in extra_params_kwargs and s_tmax != self.s_tmax:
                extra_params_kwargs['s_tmax'] = s_tmax
                self.s_tmax = s_tmax
            if 's_noise' in extra_params_kwargs and s_noise != self.s_noise:
                extra_params_kwargs['s_noise'] = s_noise
                self.s_noise = s_noise

        return extra_params_kwargs

    def create_noise_sampler(self, x, sigmas, seed):
        """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
        from k_diffusion.sampling import BrownianTreeNoiseSampler
        sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
        return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed)

    def get_sigmas(self, steps, sigmas_min=None, sigmas_max=None):
        """Return the sigma schedule for *steps*, honoring the sampler's
        preferred scheduler and the discard-next-to-last-sigma option."""
        discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)

        # One extra step compensates for the sigma dropped below.
        steps += 1 if discard_next_to_last_sigma else 0

        if self.scheduler_name == 'Automatic':
            # NOTE: mutates scheduler_name, so 'Automatic' resolves only once.
            self.scheduler_name = self.config.options.get('scheduler', None)

        scheduler = schedulers_map.get(self.scheduler_name)
        sigma_min = default(sigmas_min, self.model_wrap.sigma_min)
        sigma_max = default(sigmas_max, self.model_wrap.sigma_max)

        if scheduler is None or scheduler.function is None:
            # Fall back to the model's own discretization.
            sigmas = self.model_wrap.get_sigmas(steps)
        else:
            sigmas_kwargs = {'sigma_min': sigma_min, 'sigma_max': sigma_max}

            if scheduler.need_inner_model:
                sigmas_kwargs['inner_model'] = self.model_wrap

            sigmas = scheduler.function(n=steps, **sigmas_kwargs, device=self.device)

        if discard_next_to_last_sigma:
            sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])

        return sigmas


    def __call__(self, x, sigmas, sampler_extra_args, seed, deterministic, steps=None):
        """Scale the initial noise by sigmas[0] and run the sampling function."""
        x = x * sigmas[0]

        extra_params_kwargs = self.initialize()
        parameters = inspect.signature(self.func).parameters

        if 'n' in parameters:
            extra_params_kwargs['n'] = steps

        if 'sigma_min' in parameters:
            extra_params_kwargs['sigma_min'] = sigmas[sigmas > 0].min()
            extra_params_kwargs['sigma_max'] = sigmas.max()

        if 'sigmas' in parameters:
            extra_params_kwargs['sigmas'] = sigmas

        if self.config.options.get('brownian_noise', False):
            # Deterministic mode fixes the Brownian tree to the given seed.
            noise_sampler = self.create_noise_sampler(x, sigmas, seed) if deterministic else None
            extra_params_kwargs['noise_sampler'] = noise_sampler

        if self.config.options.get('solver_type', None) == 'heun':
            extra_params_kwargs['solver_type'] = 'heun'

        return self.func(self.sd, x, extra_args=sampler_extra_args, disable=False, **extra_params_kwargs)
refnet/sampling/scheduler.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import k_diffusion
3
+ import dataclasses
4
+
5
@dataclasses.dataclass
class Scheduler:
    """Static description of one noise (sigma) schedule."""
    # Internal identifier.
    name: str
    # Display label shown in UIs.
    label: str
    # Callable producing the sigma schedule, or None for 'Automatic'.
    function: any

    # Default rho for schedules that take one (-1 means not applicable).
    default_rho: float = -1
    # True when the schedule needs the wrapped model to derive its sigmas.
    need_inner_model: bool = False
    # Optional alternative labels.
    aliases: list = None
14
+
15
+
16
def uniform(n, sigma_min, sigma_max, inner_model, device):
    # sigma_min/sigma_max/device are unused: the wrapped model's own
    # discretization fixes the schedule; kept for a uniform scheduler signature.
    return inner_model.get_sigmas(n)
18
+
19
+
20
def sgm_uniform(n, sigma_min, sigma_max, inner_model, device):
    """SGM-style schedule: *n* sigmas uniformly spaced in t-space, with a
    terminal 0.0 appended, returned as a float tensor on *device*."""
    t_hi = inner_model.sigma_to_t(torch.tensor(sigma_max))
    t_lo = inner_model.sigma_to_t(torch.tensor(sigma_min))
    timesteps = torch.linspace(t_hi, t_lo, n + 1)[:-1]
    sigs = [inner_model.t_to_sigma(step) for step in timesteps]
    sigs.append(0.0)
    return torch.FloatTensor(sigs).to(device)
29
+
30
# Registry of available sigma schedules; 'Automatic' (function=None) defers
# to the sampler's own preferred scheduler (see KDiffusionSampler.get_sigmas).
schedulers = [
    Scheduler('automatic', 'Automatic', None),
    Scheduler('uniform', 'Uniform', uniform, need_inner_model=True),
    Scheduler('karras', 'Karras', k_diffusion.sampling.get_sigmas_karras, default_rho=7.0),
    Scheduler('exponential', 'Exponential', k_diffusion.sampling.get_sigmas_exponential),
    Scheduler('polyexponential', 'Polyexponential', k_diffusion.sampling.get_sigmas_polyexponential, default_rho=1.0),
    Scheduler('sgm_uniform', 'SGM Uniform', sgm_uniform, need_inner_model=True, aliases=["SGMUniform"]),
]
38
+
39
def get_noise_schedulers():
    """Return the display labels of every registered noise scheduler."""
    return [entry.label for entry in schedulers]
41
+
42
# Look up a scheduler by either its internal name or its display label.
schedulers_map = {**{x.name: x for x in schedulers}, **{x.label: x for x in schedulers}}
refnet/sampling/tps_transformation.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Calculate warped image using control point manipulation on a thin plate (TPS)
3
+ Based on Herve Lombaert's 2006 web article
4
+ "Manual Registration with Thin Plates"
5
+ (https://profs.etsmtl.ca/hlombaert/thinplates/)
6
+
7
+ Implementation by Yucheol Jung <ycjung@postech.ac.kr>
8
+ '''
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import PIL.Image as Image
13
+ import torchvision.transforms as tf
14
+
15
+
16
def tps_warp(images, num_points=10, perturbation_strength=10, random=True, pts_before=None, pts_after=None):
    """Warp *images* (N, C, H, W) with a thin-plate-spline deformation.

    When ``random`` is True, samples ``num_points`` control points uniformly
    over the image and perturbs them with Gaussian noise of the given
    strength; otherwise the caller must supply ``pts_before``/``pts_after``
    of shape (N, T, 2) in pixel coordinates (x, y).

    FIX: random points were scaled by ``[[[h, w]]]``, i.e. coordinate 0 (x)
    spanned [0, h) — but _tps_warp normalizes coordinate 0 by W and
    coordinate 1 by H, so the scale must be (w, h). Wrong only for
    non-square inputs.
    """
    if random:
        b, c, h, w = images.shape
        device, dtype = images.device, images.dtype
        # (x, y) order: x in [0, w), y in [0, h), matching _tps_warp's normalization.
        scale = torch.tensor([[[w, h]]], dtype=dtype, device=device)
        pts_before = torch.rand([b, num_points, 2], dtype=dtype, device=device) * scale
        pts_after = pts_before + torch.randn([b, num_points, 2], dtype=dtype, device=device) * perturbation_strength
    elif pts_before is None or pts_after is None:
        raise ValueError("pts_before and pts_after must be provided when random=False")
    return _tps_warp(images, pts_before, pts_after)
23
+
24
def _tps_warp(im, pts_before, pts_after, normalize=True):
    '''
    Deforms image according to movement of pts_before and pts_after

    Args)
        im          torch.Tensor object of size NxCxHxW
        pts_before  torch.Tensor object of size NxTx2 (T is # control pts)
        pts_after   torch.Tensor object of size NxTx2 (T is # control pts)

    Points are (x, y) pairs; with normalize=True they are given in pixel
    coordinates and mapped here into grid_sample's [-1, 1] range
    (coordinate 0 divided by 0.5*W, coordinate 1 by 0.5*H).
    '''
    # check input requirements
    assert (4 == im.dim())
    assert (3 == pts_after.dim())
    assert (3 == pts_before.dim())
    N = im.size()[0]
    assert (N == pts_after.size()[0] and N == pts_before.size()[0])
    assert (2 == pts_after.size()[2] and 2 == pts_before.size()[2])
    T = pts_after.size()[1]
    assert (T == pts_before.size()[1])
    H = im.size()[2]
    W = im.size()[3]

    if normalize:
        # Clone before in-place edits so the caller's tensors are untouched.
        pts_after = pts_after.clone()
        pts_after[:, :, 0] /= 0.5 * W
        pts_after[:, :, 1] /= 0.5 * H
        pts_after -= 1
        pts_before = pts_before.clone()
        pts_before[:, :, 0] /= 0.5 * W
        pts_before[:, :, 1] /= 0.5 * H
        pts_before -= 1

    def construct_P():
        '''
        Consturcts matrix P of size NxTx3 where
            P[n,i,0] := 1
            P[n,i,1:] := pts_after[n]
        '''
        # Create matrix P with same configuration as 'pts_after'
        P = pts_after.new_zeros((N, T, 3))
        P[:, :, 0] = 1
        P[:, :, 1:] = pts_after

        return P

    def calc_U(pt1, pt2):
        '''
        Calculate distance U between pt1 and pt2

        U(r) := r**2 * log(r)
        where
            r := |pt1 - pt2|_2

        Args)
            pt1  torch.Tensor object, last dim is always 2
            pt2  torch.Tensor object, last dim is always 2
        '''
        assert (2 == pt1.size()[-1])
        assert (2 == pt2.size()[-1])

        diff = pt1 - pt2
        sq_diff = diff ** 2
        sq_diff_sum = sq_diff.sum(-1)
        r = sq_diff_sum.sqrt()

        # Adds 1e-6 for numerical stability
        return (r ** 2) * torch.log(r + 1e-6)

    def construct_K():
        '''
        Consturcts matrix K of size NxTxT where
            K[n,i,j] := U(|pts_after[n,i] - pts_after[n,j]|_2)
        '''

        # Assuming the number of control points are small enough,
        # We just use for-loop for easy-to-read code

        # Create matrix K with same configuration as 'pts_after'
        K = pts_after.new_zeros((N, T, T))
        for i in range(T):
            for j in range(T):
                K[:, i, j] = calc_U(pts_after[:, i, :], pts_after[:, j, :])

        return K

    def construct_L():
        '''
        Consturcts matrix L of size Nx(T+3)x(T+3) where
            L[n] = [[ K[n]     P[n] ]]
                   [[ P[n]^T   0    ]]
        '''
        P = construct_P()
        K = construct_K()

        # Create matrix L with same configuration as 'K'
        L = K.new_zeros((N, T + 3, T + 3))

        # Fill L matrix
        L[:, :T, :T] = K
        L[:, :T, T:(T + 3)] = P
        L[:, T:(T + 3), :T] = P.transpose(1, 2)

        return L

    def construct_uv_grid():
        '''
        Returns H x W x 2 tensor uv with UV coordinate as its elements
        uv[:,:,0] is H x W grid of x values
        uv[:,:,1] is H x W grid of y values
        '''
        u_range = torch.arange(
            start=-1.0, end=1.0, step=2.0 / W, device=im.device)
        assert (W == u_range.size()[0])
        u = u_range.new_zeros((H, W))
        u[:] = u_range

        v_range = torch.arange(
            start=-1.0, end=1.0, step=2.0 / H, device=im.device)
        assert (H == v_range.size()[0])
        vt = v_range.new_zeros((W, H))
        vt[:] = v_range
        v = vt.transpose(0, 1)

        return torch.stack([u, v], dim=2)

    L = construct_L()
    VT = pts_before.new_zeros((N, T + 3, 2))
    # Use delta x and delta y as known heights of the surface
    VT[:, :T, :] = pts_before - pts_after

    # Solve Lx = VT
    # x is of shape (N, T+3, 2)
    # x[:,:,0] represents surface parameters for dx surface
    #   (dx values as surface height (z))
    # x[:,:,1] represents surface parameters for dy surface
    #   (dy values as surface height (z))
    x = torch.linalg.solve(L, VT)

    uv = construct_uv_grid()
    uv_batch = uv.repeat((N, 1, 1, 1))

    def calc_dxdy():
        '''
        Calculate surface height for each uv coordinate

        Returns NxHxWx2 tensor
        '''

        # control points of size NxTxHxWx2
        cp = uv.new_zeros((H, W, N, T, 2))
        cp[:, :, :] = pts_after
        cp = cp.permute([2, 3, 0, 1, 4])

        U = calc_U(uv, cp)  # U value matrix of size NxTxHxW
        w, a = x[:, :T, :], x[:, T:, :]  # w is of size NxTx2, a is of size Nx3x2
        w_x, w_y = w[:, :, 0], w[:, :, 1]  # NxT each
        a_x, a_y = a[:, :, 0], a[:, :, 1]  # Nx3 each
        # Affine part + radial-basis part, evaluated per pixel via einsum.
        dx = (
            a_x[:, 0].repeat((H, W, 1)).permute(2, 0, 1) +
            torch.einsum('nhwd,nd->nhw', uv_batch, a_x[:, 1:]) +
            torch.einsum('nthw,nt->nhw', U, w_x))  # dx values of NxHxW
        dy = (
            a_y[:, 0].repeat((H, W, 1)).permute(2, 0, 1) +
            torch.einsum('nhwd,nd->nhw', uv_batch, a_y[:, 1:]) +
            torch.einsum('nthw,nt->nhw', U, w_y))  # dy values of NxHxW

        return torch.stack([dx, dy], dim=3)

    dxdy = calc_dxdy()
    flow_field = uv + dxdy

    # NOTE(review): grid_sample uses its default align_corners behavior here —
    # confirm that matches the [-1, 1) grid built by construct_uv_grid.
    return F.grid_sample(im, flow_field.to(im.dtype))
195
+
196
if __name__ == '__main__':
    # Quick visual smoke test: warp a local sample image with random control
    # points and display the result (requires the hard-coded path to exist).
    num_points = 10
    perturbation_strength = 10
    img = tf.ToTensor()(Image.open("../../miniset/origin/109281263.jpg").convert("RGB")).unsqueeze(0)
    # img = tf.ToTensor()(Image.open("../../miniset/origin/109281263.jpg").convert("RGB").resize((224, 224))).unsqueeze(0)
    img = tps_warp(img, num_points= num_points, perturbation_strength = perturbation_strength).squeeze(0)
    img = tf.ToPILImage()(img)
    img.show()