Spaces:
Running on Zero
Running on Zero
Commit ·
d066167
0
Parent(s):
initialize huggingface space demo
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- app.py +227 -0
- backend/__init__.py +16 -0
- backend/appfunc.py +298 -0
- backend/functool.py +276 -0
- backend/style.css +181 -0
- configs/inference/sdxl.yaml +88 -0
- configs/inference/xlv2.yaml +108 -0
- configs/scheduler_cfgs/ddim.yaml +10 -0
- configs/scheduler_cfgs/dpm.yaml +8 -0
- configs/scheduler_cfgs/dpm_sde.yaml +9 -0
- configs/scheduler_cfgs/lms.yaml +9 -0
- configs/scheduler_cfgs/pndm.yaml +10 -0
- k_diffusion/__init__.py +8 -0
- k_diffusion/external.py +181 -0
- k_diffusion/sampling.py +702 -0
- k_diffusion/utils.py +457 -0
- ldm/modules/diffusionmodules/__init__.py +0 -0
- ldm/modules/diffusionmodules/model.py +488 -0
- ldm/modules/distributions/__init__.py +0 -0
- ldm/modules/distributions/distributions.py +92 -0
- preprocessor/__init__.py +124 -0
- preprocessor/anime2sketch.py +119 -0
- preprocessor/anime_segment.py +487 -0
- preprocessor/manga_line_extractor.py +187 -0
- preprocessor/sk_model.py +94 -0
- preprocessor/sketchKeras.py +153 -0
- refnet/__init__.py +0 -0
- refnet/ldm/__init__.py +1 -0
- refnet/ldm/ddpm.py +236 -0
- refnet/ldm/openaimodel.py +386 -0
- refnet/ldm/util.py +289 -0
- refnet/modules/__init__.py +34 -0
- refnet/modules/attention.py +309 -0
- refnet/modules/attn_utils.py +155 -0
- refnet/modules/embedder.py +489 -0
- refnet/modules/encoder.py +224 -0
- refnet/modules/layers.py +99 -0
- refnet/modules/lora.py +370 -0
- refnet/modules/proj.py +142 -0
- refnet/modules/reference_net.py +430 -0
- refnet/modules/transformer.py +232 -0
- refnet/modules/unet.py +421 -0
- refnet/modules/unet_old.py +596 -0
- refnet/sampling/__init__.py +11 -0
- refnet/sampling/denoiser.py +181 -0
- refnet/sampling/hook.py +257 -0
- refnet/sampling/manipulation.py +135 -0
- refnet/sampling/sampler.py +192 -0
- refnet/sampling/scheduler.py +42 -0
- refnet/sampling/tps_transformation.py +203 -0
app.py
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import argparse
|
| 3 |
+
|
| 4 |
+
from refnet.sampling import get_noise_schedulers, get_sampler_list
|
| 5 |
+
from functools import partial
|
| 6 |
+
from backend import *
|
| 7 |
+
|
| 8 |
+
# External links rendered as badges in the UI header: one paper link per
# model generation, plus the released weights and the source repository.
links = {
    "base": "https://arxiv.org/abs/2401.01456",
    "v1": "https://openaccess.thecvf.com/content/WACV2025/html/Yan_ColorizeDiffusion_Improving_Reference-Based_Sketch_Colorization_with_Latent_Diffusion_Model_WACV_2025_paper.html",
    "v1.5": "https://arxiv.org/abs/2502.19937v1",
    "v2": "https://arxiv.org/abs/2504.06895",
    # NOTE(review): "2601.04883" would be a January-2026 arXiv id -- confirm.
    "xl": "https://arxiv.org/abs/2601.04883",
    "weights": "https://huggingface.co/tellurion/colorizer/tree/main",
    "github": "https://github.com/tellurion-kanata/colorizeDiffusion",
}
|
| 17 |
+
|
| 18 |
+
def app_options(argv=None):
    """Parse command-line options for the demo app.

    Args:
        argv: Optional list of argument strings. Defaults to ``sys.argv[1:]``
              (argparse's own default); accepting an explicit list keeps the
              parser testable and is backward compatible with the old
              zero-argument call.

    Returns:
        argparse.Namespace with ``server_name``, ``server_port``, ``share``
        and ``enable_text_manipulation`` attributes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--server_name", '-addr', type=str, default="0.0.0.0")
    parser.add_argument("--server_port", '-port', type=int, default=7860)
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--enable_text_manipulation", '-manipulate', action="store_true")
    return parser.parse_args(argv)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def init_interface(opt, *args, **kwargs) -> None:
    """Build the Gradio Blocks UI and launch the (blocking) web server.

    Args:
        opt: Parsed options from app_options(); provides the server address,
             port, public-share flag, and whether the text-manipulation
             widgets are visible.
        *args, **kwargs: Unused; accepted for forward compatibility.

    Returns:
        None -- block.launch() serves until the process is interrupted.
    """
    sampler_list = get_sampler_list()
    scheduler_list = get_noise_schedulers()

    # Factory for the three identically-configured image inputs below.
    img_block = partial(gr.Image, type="pil", height=300, interactive=True, show_label=True, format="png")
    with gr.Blocks(
        title = "Colorize Diffusion",
        css_paths = "backend/style.css",
        theme = gr.themes.Ocean(),
        elem_id = "main-interface",
        analytics_enabled = False,
        fill_width = True
    ) as block:
        # Header row: app title plus badge links (papers, weights, repo, license).
        # NOTE(review): the "base" badge text shows arXiv 2407.15886 while
        # links["base"] points at 2401.01456 -- confirm which id is intended.
        with gr.Row(elem_id="header-row", equal_height=True, variant="panel"):
            gr.Markdown(f"""<div class="header-container">
                <div class="app-header"><span class="emoji">🎨</span><span class="title-text">Colorize Diffusion</span></div>
                <div class="paper-links-icons">
                    <a href="{links['base']}" target="_blank">
                        <img src="https://img.shields.io/badge/arXiv-2407.15886 (base)-B31B1B?style=flat&logo=arXiv" alt="arXiv Paper">
                    </a>
                    <a href="{links['v1']}" target="_blank">
                        <img src="https://img.shields.io/badge/WACV 2025-v1-0CA4A5?style=flat&logo=Semantic%20Web" alt="WACV 2025">
                    </a>
                    <a href="{links['v1.5']}" target="_blank">
                        <img src="https://img.shields.io/badge/CVPR 2025-v1.5-0CA4A5?style=flat&logo=Semantic%20Web" alt="CVPR 2025">
                    </a>
                    <a href="{links['v2']}" target="_blank">
                        <img src="https://img.shields.io/badge/arXiv-2504.06895 (v2)-B31B1B?style=flat&logo=arXiv" alt="arXiv v2 Paper">
                    </a>
                    <a href="{links['weights']}" target="_blank">
                        <img src="https://img.shields.io/badge/Hugging%20Face-Model%20Weights-FF9D00?style=flat&logo=Hugging%20Face" alt="Model Weights">
                    </a>
                    <a href="{links['github']}" target="_blank">
                        <img src="https://img.shields.io/badge/GitHub-Repo-blue?style=flat&logo=GitHub" alt="GitHub">
                    </a>
                    <a href="https://github.com/tellurion-kanata/colorizeDiffusion/blob/master/LICENSE" target="_blank">
                        <img src="https://img.shields.io/badge/License-CC--BY--NC--SA%204.0-4CAF50?style=flat&logo=Creative%20Commons" alt="License">
                    </a>
                </div>
            </div>""")

        with gr.Row(elem_id="content-row", equal_height=False, variant="panel"):
            # Left column: inputs and generation parameters.
            with gr.Column():
                # Text-manipulation widgets, hidden unless enabled via CLI flag.
                with gr.Row(visible=opt.enable_text_manipulation):
                    target = gr.Textbox(label="Target prompt", value="", scale=2)
                    anchor = gr.Textbox(label="Anchor prompt", value="", scale=2)
                    control = gr.Textbox(label="Control prompt", value="", scale=2)
                with gr.Row(visible=opt.enable_text_manipulation):
                    target_scale = gr.Slider(label="Target scale", value=0.0, minimum=0, maximum=15.0, step=0.25, scale=2)
                    ts0 = gr.Slider(label="Threshold 0", value=0.5, minimum=0, maximum=1.0, step=0.01)
                    ts1 = gr.Slider(label="Threshold 1", value=0.55, minimum=0, maximum=1.0, step=0.01)
                    ts2 = gr.Slider(label="Threshold 2", value=0.65, minimum=0, maximum=1.0, step=0.01)
                    ts3 = gr.Slider(label="Threshold 3", value=0.95, minimum=0, maximum=1.0, step=0.01)
                with gr.Row(visible=opt.enable_text_manipulation):
                    enhance = gr.Checkbox(label="Enhance manipulation", value=False)
                    add_prompt = gr.Button(value="Add")
                    clear_prompt = gr.Button(value="Clear")
                    vis_button = gr.Button(value="Visualize")
                text_prompt = gr.Textbox(label="Final prompt", value="", lines=3, visible=opt.enable_text_manipulation)

                # Image inputs: the sketch to colorize, the color reference,
                # and an optional background image.
                with gr.Row():
                    sketch_img = img_block(label="Sketch")
                    reference_img = img_block(label="Reference")
                    background_img = img_block(label="Background")

                # Hidden flags kept as States so the ips list below lines up
                # with the inference() signature.
                style_enhance = gr.State(False)
                fg_enhance = gr.State(False)
                with gr.Row():
                    bg_enhance = gr.Checkbox(label="Low-level injection", value=False)
                    injection = gr.Checkbox(label="Attention injection", value=False)
                    autofit_size = gr.Checkbox(label="Autofit size", value=False)
                with gr.Row():
                    gs_r = gr.Slider(label="Reference guidance scale", minimum=1, maximum=15.0, value=4.0, step=0.5)
                    strength = gr.Slider(label="Reference strength", minimum=0, maximum=1, value=1, step=0.05)
                    fg_strength = gr.Slider(label="Foreground strength", minimum=0, maximum=1, value=1, step=0.05)
                    bg_strength = gr.Slider(label="Background strength", minimum=0, maximum=1, value=1, step=0.05)
                with gr.Row():
                    gs_s = gr.Slider(label="Sketch guidance scale", minimum=1, maximum=5.0, value=1.0, step=0.1)
                    ctl_scale = gr.Slider(label="Sketch strength", minimum=0, maximum=3, value=1, step=0.05)
                    mask_scale = gr.Slider(label="Background factor", minimum=0, maximum=2, value=1, step=0.05)
                    merge_scale = gr.Slider(label="Merging scale", minimum=0, maximum=1, value=0, step=0.05)
                with gr.Row():
                    bs = gr.Slider(label="Batch size", minimum=1, maximum=4, value=1, step=1, scale=1)
                    width = gr.Slider(label="Width", minimum=512, maximum=1536, value=1024, step=32, scale=2)
                with gr.Row():
                    step = gr.Slider(label="Step", minimum=1, maximum=100, value=20, step=1, scale=1)
                    height = gr.Slider(label="Height", minimum=512, maximum=1536, value=1024, step=32, scale=2)

                # -1 means "draw a fresh random seed" (see inference()).
                seed = gr.Slider(label="Seed", minimum=-1, maximum=MAXM_INT32, step=1, value=-1)
                with gr.Accordion("Advanced Settings", open=False):
                    with gr.Row():
                        crop = gr.Checkbox(label="Crop result", value=False, scale=1)
                        remove_fg = gr.Checkbox(label="Remove foreground in background input", value=False, scale=2)
                        rmbg = gr.Checkbox(label="Remove background in result", value=False, scale=2)
                        latent_inpaint = gr.Checkbox(label="Latent copy BG input", value=False, scale=2)
                    with gr.Row():
                        injection_control_scale = gr.Slider(label="Injection fidelity (sketch)", minimum=0.0,
                                                            maximum=2.0, value=0, step=0.05)
                        injection_fidelity = gr.Slider(label="Injection fidelity (reference)", minimum=0.0,
                                                       maximum=1.0, value=0.5, step=0.05)
                        injection_start_step = gr.Slider(label="Injection start step", minimum=0.0, maximum=1.0,
                                                         value=0, step=0.05)

                with gr.Row():
                    reuse_seed = gr.Button(value="Reuse Seed")
                    random_seed = gr.Button(value="Random Seed")

            # Right column: results, model selection and sampler settings.
            with gr.Column():
                result_gallery = gr.Gallery(
                    label='Output', show_label=False, elem_id="gallery", preview=True, type="pil", format="png"
                )
                run_button = gr.Button("Generate", variant="primary", size="lg")
                with gr.Row():
                    mask_ts = gr.Slider(label="Reference mask threshold", minimum=0., maximum=1., value=0.5, step=0.01)
                    mask_ss = gr.Slider(label="Sketch mask threshold", minimum=0., maximum=1., value=0.05, step=0.01)
                    pad_scale = gr.Slider(label="Reference padding scale", minimum=1, maximum=2, value=1, step=0.05)

                with gr.Row():
                    sd_model = gr.Dropdown(choices=get_available_models(), label="Models",
                                           value=get_available_models()[0])
                    extractor_model = gr.Dropdown(choices=line_extractor_list,
                                                  label="Line extractor", value=default_line_extractor)
                    mask_model = gr.Dropdown(choices=mask_extractor_list, label="Reference mask extractor",
                                             value=default_mask_extractor)
                with gr.Row():
                    sampler = gr.Dropdown(choices=sampler_list, value="DPM++ 3M SDE", label="Sampler")
                    scheduler = gr.Dropdown(choices=scheduler_list, value=scheduler_list[0], label="Noise scheduler")
                    preprocessor = gr.Dropdown(choices=["none", "extract", "invert", "invert-webui"],
                                               label="Sketch preprocessor", value="invert")

                with gr.Row():
                    deterministic = gr.Checkbox(label="Deterministic batch seed", value=False)
                    save_memory = gr.Checkbox(label="Save memory", value=True)

                # Hidden states for unused advanced controls
                fg_disentangle_scale = gr.State(1.0)
                start_step = gr.State(0.0)
                end_step = gr.State(1.0)
                no_start_step = gr.State(-0.05)
                no_end_step = gr.State(-0.05)
                return_inter = gr.State(False)
                accurate = gr.State(False)
                enc_scale = gr.State(1.0)
                middle_scale = gr.State(1.0)
                low_scale = gr.State(1.0)
                ctl_scale_1 = gr.State(1.0)
                ctl_scale_2 = gr.State(1.0)
                ctl_scale_3 = gr.State(1.0)
                ctl_scale_4 = gr.State(1.0)

        # Event wiring: prompt editing helpers.
        add_prompt.click(fn=apppend_prompt,
                         inputs=[target, anchor, control, target_scale, enhance, ts0, ts1, ts2, ts3, text_prompt],
                         outputs=[target, anchor, control, target_scale, enhance, ts0, ts1, ts2, ts3, text_prompt])
        clear_prompt.click(fn=clear_prompts, outputs=[text_prompt])

        reuse_seed.click(fn=get_last_seed, outputs=[seed])
        random_seed.click(fn=reset_random_seed, outputs=[seed])

        # Model/preprocessor switching on dropdown changes.
        extractor_model.input(fn=switch_extractor, inputs=[extractor_model])
        sd_model.input(fn=load_model, inputs=[sd_model])
        mask_model.input(fn=switch_mask_extractor, inputs=[mask_model])

        # Ordered argument list for inference(); MUST stay in sync with the
        # positional parameter order of backend.appfunc.inference.
        ips = [style_enhance, bg_enhance, fg_enhance, fg_disentangle_scale,
               bs, sketch_img, reference_img, background_img, mask_ts, mask_ss, gs_r, gs_s, ctl_scale,
               ctl_scale_1, ctl_scale_2, ctl_scale_3, ctl_scale_4, fg_strength, bg_strength, merge_scale,
               mask_scale, height, width, seed, save_memory, step, injection, autofit_size,
               remove_fg, rmbg, latent_inpaint, injection_control_scale, injection_fidelity, injection_start_step,
               crop, pad_scale, start_step, end_step, no_start_step, no_end_step, return_inter, sampler, scheduler,
               preprocessor, deterministic, text_prompt, target, anchor, control, target_scale, ts0, ts1, ts2, ts3,
               enhance, accurate, enc_scale, middle_scale, low_scale, strength]

        run_button.click(
            fn = inference,
            inputs = ips,
            outputs = [result_gallery],
        )

        vis_button.click(
            fn = visualize,
            inputs = [reference_img, text_prompt, control, ts0, ts1, ts2, ts3],
            outputs = [result_gallery],
        )

    # Blocks until the server is stopped.
    block.launch(
        server_name = opt.server_name,
        share = opt.share,
        server_port = opt.server_port,
    )
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
if __name__ == '__main__':
    opt = app_options()
    try:
        # Eagerly load the default checkpoint and both preprocessors so the
        # first generation request does not pay the startup cost.
        models = get_available_models()
        load_model(models[0])
        switch_extractor(default_line_extractor)
        switch_mask_extractor(default_mask_extractor)
        # NOTE: init_interface() returns None and blocks inside block.launch();
        # "interface" is never used afterwards.
        interface = init_interface(opt)
    except Exception as e:
        print(f"Error initializing interface: {e}")
        raise
|
backend/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Public interface of the backend package.

Re-exports the Gradio callback helpers from ``backend.appfunc`` and defines
the preprocessor model-name lists used by the UI dropdowns in app.py.
"""
from .appfunc import *


__all__ = [
    'switch_extractor', 'switch_mask_extractor',
    'get_available_models', 'load_model', 'inference', 'reset_random_seed', 'get_last_seed',
    # NOTE(review): "apppend_prompt" (triple "p") matches the function name as
    # defined in appfunc.py -- keep both spellings in sync if either is renamed.
    'apppend_prompt', 'clear_prompts', 'visualize',
    'default_line_extractor', 'default_mask_extractor', 'MAXM_INT32',
    'mask_extractor_list', 'line_extractor_list',
]


# Preprocessor models preselected in the UI dropdowns.
default_line_extractor = "lineart_keras"
default_mask_extractor = "rmbg-v2"
# Model identifiers accepted by preprocessor.create_model().
mask_extractor_list = ["none", "ISNet", "rmbg-v2", "BiRefNet", "BiRefNet_HR"]
line_extractor_list = ["lineart", "lineart_denoise", "lineart_keras", "lineart_sk"]
|
backend/appfunc.py
ADDED
|
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import random
import traceback
import gradio as gr
import os.path as osp

from huggingface_hub import hf_hub_download

from omegaconf import OmegaConf
from refnet.util import instantiate_from_config
from preprocessor import create_model
from .functool import *

# Currently loaded diffusion model plus its bookkeeping; all three are
# mutated by load_model() below.
model = None

model_type = ""
current_checkpoint = ""
# Seed used by the most recent generation; None until the first run.
global_seed = None

# The sketch-mask extractor is instantiated once at import time and parked on
# CPU; inference() moves it to GPU only while masks are computed.
smask_extractor = create_model("ISNet-sketch").cpu()

# Upper bound for the seed slider and the random seed draw.
# Fixed: the previous value 429496729 was a truncated 4294967295 (uint32 max
# missing its last digit); use the actual int32 maximum so the constant
# matches its name.
MAXM_INT32 = 2147483647

# HuggingFace model repository
HF_REPO_ID = "tellurion/colorizer"
MODEL_CACHE_DIR = "models"

# Model registry: filename -> model_type
MODEL_REGISTRY = {
    "sdxl.safetensors": "sdxl",
    "xlv2.safetensors": "xlv2",
}

# Known model-type prefixes; used as a fallback in load_model() when a
# checkpoint filename is not present in MODEL_REGISTRY.
model_types = ["sdxl", "xlv2"]

'''
Gradio UI functions
'''
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def get_available_models():
    """List the checkpoint filenames known to the model registry."""
    return [name for name in MODEL_REGISTRY]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def download_model(filename):
    """Download a checkpoint from HuggingFace Hub unless already cached.

    Args:
        filename: Checkpoint filename inside the HF_REPO_ID repository.

    Returns:
        Local filesystem path to the checkpoint.
    """
    os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
    local_path = osp.join(MODEL_CACHE_DIR, filename)
    if osp.exists(local_path):
        return local_path

    # Fixed: the status messages previously printed a literal "(unknown)"
    # instead of the checkpoint actually being fetched.
    print(f"Downloading {filename} from {HF_REPO_ID}...")
    gr.Info(f"Downloading {filename}...")
    path = hf_hub_download(
        repo_id=HF_REPO_ID,
        filename=filename,
        local_dir=MODEL_CACHE_DIR,
    )
    print(f"Downloaded to {path}")
    return path
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def switch_extractor(type):
    """Load the line-extractor model named *type* into the module global.

    Failures are reported via gr.Info instead of raised, so the UI keeps the
    previously loaded extractor.
    """
    global line_extractor
    try:
        line_extractor = create_model(type)
        gr.Info(f"Switched to {type} extractor")
    except Exception as e:
        print(f"Error info: {e}")
        # Fixed: traceback.print_exc() returns None, so wrapping it in
        # print() emitted a stray "None" line in the log.
        traceback.print_exc()
        gr.Info(f"Failed in loading {type} extractor")
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def switch_mask_extractor(type):
    """Load the reference-mask extractor named *type* into the module global.

    Failures are reported via gr.Info instead of raised, so the UI keeps the
    previously loaded extractor.
    """
    global mask_extractor
    try:
        mask_extractor = create_model(type)
        gr.Info(f"Switched to {type} extractor")
    except Exception as e:
        print(f"Error info: {e}")
        # Fixed: traceback.print_exc() returns None, so wrapping it in
        # print() emitted a stray "None" line in the log.
        traceback.print_exc()
        gr.Info(f"Failed in loading {type} extractor")
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def apppend_prompt(target, anchor, control, scale, enhance, ts0, ts1, ts2, ts3, prompt):
    """Append one manipulation entry to the accumulated prompt text.

    Empty target/anchor/control fields are normalised to "none".  Returns the
    reset values for every widget plus the extended prompt string, matching
    the outputs list wired up in app.py.
    """
    target, anchor, control = (s.strip() or "none" for s in (target, anchor, control))
    entry = (f"\n[target] {target}; [anchor] {anchor}; [control] {control}; [scale] {str(scale)}; "
             f"[enhanced] {str(enhance)}; [ts0] {str(ts0)}; [ts1] {str(ts1)}; [ts2] {str(ts2)}; [ts3] {str(ts3)}")
    return "", "", "", 0.0, False, 0.5, 0.55, 0.65, 0.95, (prompt + entry).strip()
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def clear_prompts():
    """Reset the accumulated prompt textbox to an empty string."""
    cleared = ""
    return cleared
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def load_model(ckpt_name):
    """Instantiate (if needed) and load the checkpoint *ckpt_name*.

    Mutates the module globals ``model``, ``model_type`` and
    ``current_checkpoint``.  Errors are reported via gr.Info rather than
    raised, so the UI stays alive on failure.
    """
    global model, model_type, current_checkpoint
    config_root = "configs/inference"

    try:
        # Determine model type from registry or filename prefix
        new_model_type = MODEL_REGISTRY.get(ckpt_name, "")
        if not new_model_type:
            for key in model_types:
                if ckpt_name.startswith(key):
                    new_model_type = key
                    break

        # NOTE(review): "model" is assigned at module import time, so
        # `"model" in globals()` is always True here; the rebuild is
        # effectively driven by the model_type comparison alone.
        if model_type != new_model_type or not "model" in globals():
            if "model" in globals() and exists(model):
                # Drop the old network before building the new one to free memory.
                del model
            config_path = osp.join(config_root, f"{new_model_type}.yaml")
            new_model = instantiate_from_config(OmegaConf.load(config_path).model).cpu().eval()
            print(f"Switched to {new_model_type} model, loading weights from [{ckpt_name}]...")
            model = new_model

        # Download model from HF Hub
        local_path = download_model(ckpt_name)

        # Checkpoints with "eps" in the filename use epsilon prediction;
        # everything else is treated as v-prediction.
        model.parameterization = "eps" if ckpt_name.find("eps") > -1 else "v"
        model.init_from_ckpt(local_path, logging=True)
        model.switch_to_fp16()

        model_type = new_model_type
        current_checkpoint = ckpt_name
        print(f"Loaded model from [{ckpt_name}], model_type [{model_type}].")
        gr.Info("Loaded model successfully.")

    except Exception as e:
        print(f"Error type: {e}")
        print(traceback.print_exc())
        gr.Info("Failed in loading model.")
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def get_last_seed():
    """Return the seed used by the last generation, or -1 when none yet."""
    if global_seed:
        return global_seed
    return -1
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def reset_random_seed():
    """Sentinel value telling the backend to draw a fresh random seed."""
    sentinel = -1
    return sentinel
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def visualize(reference, text, *args):
    """Render attention heatmaps for *reference* under the parsed prompts.

    *args are forwarded to visualize_heatmaps; per the caller in app.py they
    are the control string and the four threshold sliders.
    """
    return visualize_heatmaps(model, reference, parse_prompts(text), *args)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def set_cas_scales(accurate, cas_args):
    """Build the reference-strength spec consumed by model.generate().

    ``cas_args`` is (enc_scale, middle_scale, low_scale, strength, *extra).
    In the default (non-accurate) path the three level scales are multiplied
    by the overall strength; in the accurate path the remaining raw scales
    are passed through as a list.
    """
    enc_scale, middle_scale, low_scale, strength = cas_args[:4]
    if accurate:
        return {
            "level_control": False,
            "scales": list(cas_args[4:]),
        }
    return {
        "level_control": True,
        "scales": {
            "encoder": enc_scale * strength,
            "middle": middle_scale * strength,
            "low": low_scale * strength,
        },
    }
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
@torch.no_grad()
def inference(
        style_enhance, bg_enhance, fg_enhance, fg_disentangle_scale,
        bs, input_s, input_r, input_bg, mask_ts, mask_ss, gs_r, gs_s, ctl_scale,
        ctl_scale_1, ctl_scale_2, ctl_scale_3, ctl_scale_4,
        fg_strength, bg_strength, merge_scale, mask_scale, height, width, seed, low_vram, step,
        injection, autofit_size, remove_fg, rmbg, latent_inpaint, infid_x, infid_r, injstep, crop, pad_scale,
        start_step, end_step, no_start_step, no_end_step, return_inter, sampler, scheduler, preprocess,
        deterministic, text, target, anchor, control, target_scale, ts0, ts1, ts2, ts3, enhance, accurate,
        *args
):
    """Run one reference-based colorization pass (the "Generate" callback).

    Parameter order must stay in sync with the ``ips`` list in app.py.
    ``*args`` carries (enc_scale, middle_scale, low_scale, strength), the four
    hidden gr.State values appended at the end of ``ips``; they are consumed
    by set_cas_scales().  Returns the post-processed image batch for the
    result gallery.
    """
    global global_seed, line_extractor, mask_extractor
    # Draw a fresh seed when the slider is at -1; remember it for "Reuse Seed".
    global_seed = seed if seed > -1 else random.randint(0, MAXM_INT32)
    torch.manual_seed(global_seed)

    # Auto-fit size based on sketch dimensions
    if autofit_size and exists(input_s):
        sketch_w, sketch_h = input_s.size
        aspect_ratio = sketch_w / sketch_h
        # Target roughly a 1024x1024 pixel budget at the sketch's aspect
        # ratio, rounded to multiples of 32 and clamped to [768, 1536].
        target_area = 1024 * 1024
        new_h = int((target_area / aspect_ratio) ** 0.5)
        new_w = int(new_h * aspect_ratio)
        height = ((new_h + 16) // 32) * 32
        width = ((new_w + 16) // 32) * 32
        height = max(768, min(1536, height))
        width = max(768, min(1536, width))
        gr.Info(f"Auto-fitted size: {width}x{height}")

    smask, rmask, bgmask = None, None, None
    manipulation_params = parse_prompts(text, target, anchor, control, target_scale, ts0, ts1, ts2, ts3, enhance)
    inputs = preprocessing_inputs(
        sketch = input_s,
        reference = input_r,
        background = input_bg,
        preprocess = preprocess,
        hook = injection,
        resolution = (height, width),
        extractor = line_extractor,
        pad_scale = pad_scale,
    )
    sketch, reference, background, original_shape, inject_xr, inject_xs, white_sketch = inputs

    cond = {"reference": reference, "sketch": sketch, "background": background}
    mask_guided = bg_enhance or fg_enhance

    # Foreground/background guidance needs masks for both the sketch and the
    # reference (or background) image.
    if exists(white_sketch) and exists(reference) and mask_guided:
        mask_extractor.cuda()
        smask_extractor.cuda()
        smask = smask_extractor.proceed(
            x=white_sketch, pil_x=input_s, th=height, tw=width, threshold=mask_ss, crop=False
        )

        if exists(background) and remove_fg:
            # Blank out the foreground region of the background input.
            bgmask = mask_extractor.proceed(x=background, pil_x=input_bg, threshold=mask_ts, dilate=True)
            filtered_background = torch.where(bgmask < mask_ts, background, torch.ones_like(background))
            cond.update({"background": filtered_background, "rmask": bgmask})
        else:
            rmask = mask_extractor.proceed(x=reference, pil_x=input_r, threshold=mask_ts, dilate=True)
            cond.update({"rmask": rmask})
            # Binarized local copy used later by postprocess().
            rmask = torch.where(rmask > 0.5, torch.ones_like(rmask), torch.zeros_like(rmask))
        cond.update({"smask": smask})
        # Park the extractors back on CPU to free VRAM for the diffusion model.
        smask_extractor.cpu()
        mask_extractor.cpu()

    scale_strength = set_cas_scales(accurate, args)
    # Per-level sketch control scales, modulated by the overall slider value.
    ctl_scales = [ctl_scale_1, ctl_scale_2, ctl_scale_3, ctl_scale_4]
    ctl_scales = [t * ctl_scale for t in ctl_scales]

    results = model.generate(
        # Colorization mode
        style_enhance = style_enhance,
        bg_enhance = bg_enhance,
        fg_enhance = fg_enhance,
        fg_disentangle_scale = fg_disentangle_scale,
        latent_inpaint = latent_inpaint,

        # Conditional inputs
        cond = cond,
        ctl_scale = ctl_scales,
        merge_scale = merge_scale,
        mask_scale = mask_scale,
        mask_thresh = mask_ts,
        mask_thresh_sketch = mask_ss,

        # Sampling settings
        bs = bs,
        gs = [gs_r, gs_s],
        sampler = sampler,
        scheduler = scheduler,
        start_step = start_step,
        end_step = end_step,
        no_start_step = no_start_step,
        no_end_step = no_end_step,
        strength = scale_strength,
        fg_strength = fg_strength,
        bg_strength = bg_strength,
        seed = global_seed,
        deterministic = deterministic,
        height = height,
        width = width,
        step = step,

        # Injection settings
        injection = injection,
        injection_cfg = infid_r,
        injection_control = infid_x,
        injection_start_step = injstep,
        hook_xr = inject_xr,
        hook_xs = inject_xs,

        # Additional settings
        low_vram = low_vram,
        return_intermediate = return_inter,
        manipulation_params = manipulation_params,
    )

    if rmbg:
        # NOTE(review): this moves mask_extractor to the GPU but then calls
        # smask_extractor.proceed -- confirm which extractor/device is intended.
        mask_extractor.cuda()
        mask = smask_extractor.proceed(x=-sketch, threshold=mask_ss).repeat(results.shape[0], 1, 1, 1)
        results = torch.where(mask >= mask_ss, results, torch.ones_like(results))
        mask_extractor.cpu()

    results = postprocess(results, sketch, reference, background, crop, original_shape,
                          mask_guided, smask, rmask, bgmask, mask_ts, mask_ss)
    torch.cuda.empty_cache()
    gr.Info("Generation completed.")
    return results
|
backend/functool.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
import PIL.Image as Image
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torchvision.transforms as transforms
|
| 8 |
+
|
| 9 |
+
from functools import partial
|
| 10 |
+
|
| 11 |
+
maxium_resolution = 4096
|
| 12 |
+
token_length = int(256 ** 0.5)
|
| 13 |
+
|
| 14 |
+
def exists(v):
    """Return True when *v* is not None (helper for optional arguments)."""
    return v is not None
|
| 16 |
+
|
| 17 |
+
# Factory for bicubic, antialiased torchvision Resize transforms; call as resize((h, w))(img).
resize = partial(transforms.Resize, interpolation=transforms.InterpolationMode.BICUBIC, antialias=True)
|
| 18 |
+
|
| 19 |
+
def resize_image(img, new_size, w, h):
    """Resize *img* so its longer side becomes *new_size*, preserving aspect ratio.

    *w* and *h* are the source width/height used to pick the orientation.
    """
    if w > h:
        target = (int(h / w * new_size), new_size)
    else:
        target = (new_size, int(w / h * new_size))
    return resize(target)(img)
|
| 25 |
+
|
| 26 |
+
def pad_image(image: torch.Tensor, h, w):
    """Center *image* on an (h, w) canvas filled with -1 (black in [-1, 1] space).

    Returns:
        The padded (b, c, h, w) tensor and (left, top, width, height) of the
        original region, so it can later be cropped back out.
    """
    b, c, src_h, src_w = image.shape
    canvas = -torch.ones([b, c, h, w], device=image.device)
    x0 = (w - src_w) // 2
    y0 = (h - src_h) // 2
    canvas[:, :, y0:y0 + src_h, x0:x0 + src_w] = image
    return canvas, (x0, y0, src_w, src_h)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def pad_image_with_margin(image: Image.Image, scale):
    """Widen *image* by *scale* with white side margins, keeping it centered.

    Args:
        image: source RGB PIL image.
        scale: width multiplier (> 1 adds margin; 1 yields an equal-width copy).
    Returns:
        A new RGB image of size (int(w * scale), h).
    """
    # NOTE: the original annotated `image: Image`, which names the PIL *module*;
    # the intended type is the Image.Image class.
    w, h = image.size
    new_w = int(w * scale)
    canvas = Image.new('RGB', (new_w, h), (255, 255, 255))
    canvas.paste(image, ((new_w - w) // 2, 0))
    return canvas
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def crop_image_from_square(square_image, original_dim):
    """Undo pad_image: crop the original region back out of a padded image.

    *original_dim* is the (left, top, width, height) tuple pad_image returned.
    """
    left, top, width, height = original_dim
    right, bottom = left + width, top + height
    return square_image.crop((left, top, right, bottom))
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def to_tensor(x, inverse=False):
    """Convert an image (PIL / HWC uint8 array) to a normalized CUDA tensor.

    Args:
        x: input accepted by torchvision ToTensor.
        inverse: if True, negate the result (flips black/white for sketches).
    Returns:
        A (1, C, H, W) float tensor in [-1, 1] on the current CUDA device.
    """
    x = transforms.ToTensor()(x).unsqueeze(0)
    # Per-channel (x - 0.5) / 0.5 maps [0, 1] -> [-1, 1].
    x = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(x).cuda()
    return x if not inverse else -x
|
| 53 |
+
|
| 54 |
+
def to_numpy(x, denormalize=True):
    """Convert a batched tensor to uint8 numpy.

    denormalize=True: treat *x* as images in [-1, 1]; return (B, H, W, C) uint8.
    denormalize=False: treat *x* as a single-channel mask in [0, 1]; return the
    first (H, W) plane as uint8.
    """
    if denormalize:
        images = (x.clamp(-1, 1) + 1.) * 127.5
        return images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
    mask = (x.clamp(0, 1) * 255)[0][0]
    return mask.cpu().numpy().astype(np.uint8)
|
| 59 |
+
|
| 60 |
+
def lineart_standard(x: Image.Image):
    """Extract line art with a difference-of-Gaussian filter ("lineart standard" style).

    Blur minus original highlights dark strokes; the per-pixel channel minimum
    keeps the strongest response. Returns a (1, 3, H, W) normalized CUDA tensor
    (via to_tensor).
    """
    x = np.array(x).astype(np.float32)
    g = cv2.GaussianBlur(x, (0, 0), 6.0)
    intensity = np.min(g - x, axis=2).clip(0, 255)
    # Normalize by a robust stroke strength; the floor of 16 keeps faint images
    # from being over-amplified. NOTE(review): if no pixel exceeds 8 the median
    # is taken over an empty array (NaN) — assumes real sketches always have
    # some strokes; confirm upstream.
    intensity /= max(16, np.median(intensity[intensity > 8]))
    intensity *= 127
    # Tile the single-channel response to 3 channels for downstream RGB code.
    intensity = np.repeat(np.expand_dims(intensity, 2), 3, axis=2)
    result = to_tensor(intensity.clip(0, 255).astype(np.uint8))
    return result
|
| 69 |
+
|
| 70 |
+
def preprocess_sketch(sketch, resolution, preprocess="none", extractor=None, new=False):
    """Convert a PIL sketch into a normalized, aspect-preserving padded tensor.

    Args:
        sketch: PIL input sketch.
        resolution: target (height, width).
        preprocess: "none" (use as-is), "invert" (negate, for white-on-black
            input), "invert-webui" (difference-of-Gaussian extraction), or any
            other value to run the neural *extractor* at 768x768.
        extractor: line extractor module; required for the neural branch.
        new: if True, remap the sketch from [-1, 1] to [0, 1].
    Returns:
        (sketch tensor, (left, top, width, height) of the un-padded region,
        inverted "white" sketch tensor).
    """
    w, h = sketch.size
    th, tw = resolution
    # Uniform scale factor so the sketch fits inside (th, tw) without distortion.
    r = min(th/h, tw/w)

    if preprocess == "none":
        sketch = to_tensor(sketch)
    elif preprocess == "invert":
        sketch = to_tensor(sketch, inverse=True)
    elif preprocess == "invert-webui":
        sketch = lineart_standard(sketch)
    else:
        # Neural extractor emits one channel; tile to RGB for the rest of the pipeline.
        sketch = extractor.proceed(resize((768, 768))(sketch)).repeat(1, 3, 1, 1)

    # Scale to fit, then center-pad up to exactly (th, tw).
    sketch, original_shape = pad_image(resize((int(h*r), int(w*r)))(sketch), th, tw)
    if new:
        sketch = ((sketch + 1) / 2.).clamp(0, 1)
        white_sketch = 1 - sketch
    else:
        white_sketch = -sketch
    return sketch, original_shape, white_sketch
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
@torch.no_grad()
def preprocessing_inputs(
    sketch: Image.Image,
    reference: Image.Image,
    background: Image.Image,
    preprocess: str,
    hook: bool,
    resolution: tuple[int, int],
    extractor: nn.Module,
    pad_scale: float = 1.,
    new = False
):
    """Prepare all generation inputs (sketch / reference / background) as tensors.

    The extractor is moved to CUDA only for the duration of this call.

    Args:
        sketch, reference, background: optional PIL inputs (any may be None).
        preprocess: sketch preprocessing mode (see preprocess_sketch).
        hook: if True, also build the reference injection tensor for the
            attention hook (requires *reference*).
        resolution: target (height, width).
        extractor: sketch line extractor module.
        pad_scale: > 1 adds white side margins to reference/background.
        new: newer normalization convention (sketch in [0, 1] instead of [-1, 1]).
    Returns:
        (sketch, reference, background, original_shape, inject_xr, inject_xs,
        white_sketch) — tensors or None where the input was absent.
    """
    extractor = extractor.cuda()
    h, w = resolution
    if exists(sketch):
        sketch, original_shape, white_sketch = preprocess_sketch(sketch, resolution, preprocess, extractor, new)
    else:
        # No sketch: use a blank canvas in the convention's "empty" value.
        sketch = torch.zeros([1, 3, h, w], device="cuda") if new else -torch.ones([1, 3, h, w], device="cuda")
        white_sketch = None
        original_shape = (0, 0, h, w)

    # inject_xs is currently unused (see the commented-out extractor path below).
    inject_xs = None
    if hook:
        assert exists(reference) and exists(extractor)
        maxm = max(h, w)
        # inject_xs = resize((h, w))(extractor.proceed(resize((maxm, maxm))(reference)).repeat(1, 3, 1, 1))
        inject_xr = to_tensor(resize((h, w))(reference))
    else:
        inject_xr = None
    # Return the extractor to CPU to free GPU memory.
    extractor = extractor.cpu()

    if exists(reference):
        if pad_scale > 1.:
            reference = pad_image_with_margin(reference, pad_scale)
        reference = to_tensor(reference)

    if exists(background):
        if pad_scale > 1.:
            background = pad_image_with_margin(background, pad_scale)
        background = to_tensor(background)

    return sketch, reference, background, original_shape, inject_xr, inject_xs, white_sketch
|
| 135 |
+
|
| 136 |
+
def postprocess(results, sketch, reference, background, crop, original_shape,
                mask_guided, smask, rmask, bgmask, mask_ts, mask_ss, new=False):
    """Assemble the gallery shown to the user: generations plus debug visuals.

    Converts tensors back to images, optionally crops generations back to the
    un-padded region, then appends the sketch, reference, background and —
    when mask guidance is on — the thresholded masks and masked composites.

    Args:
        results: batched generated image tensor.
        crop: if True, crop each result with *original_shape* (see pad_image).
        mask_guided: whether mask visualizations should be emitted.
        smask / rmask / bgmask: sketch / reference / background masks (may be None).
        mask_ts / mask_ss: thresholds below which rmask / smask values are zeroed.
        new: matches the sketch normalization used at preprocessing time.
    Returns:
        A list of PIL images / numpy arrays for the gradio gallery.
    """
    results = to_numpy(results)
    # Under the old convention the sketch is in [-1, 1] and must be denormalized.
    sketch = to_numpy(sketch, not new)[0]

    results_list = []
    for result in results:
        result = Image.fromarray(result)
        if crop:
            result = crop_image_from_square(result, original_shape)
        results_list.append(result)

    results_list.append(sketch)

    if exists(reference):
        reference = to_numpy(reference)[0]
        results_list.append(reference)
        # if vis_crossattn:
        #     results_list += visualize_attention_map(reference, results_list[0], vh, vw)

    if exists(background):
        background = to_numpy(background)[0]
        results_list.append(background)

        if exists(bgmask):
            # Show background with mask applied, and its complement, on white.
            background = Image.fromarray(background)
            results_list.append(Image.composite(
                background,
                Image.new("RGB", background.size, (255, 255, 255)),
                Image.fromarray(to_numpy(bgmask, denormalize=False), mode="L")
            ))
            results_list.append(Image.composite(
                Image.new("RGB", background.size, (255, 255, 255)),
                background,
                Image.fromarray(to_numpy(bgmask, denormalize=False), mode="L")
            ))

    if mask_guided:
        # Zero out weak sketch-mask responses below the threshold (in place).
        smask[smask < mask_ss] = 0
        results_list.append(Image.fromarray(to_numpy(smask, denormalize=False), mode="L"))

        if exists(rmask):
            reference = Image.fromarray(reference)
            rmask[rmask < mask_ts] = 0
            results_list.append(Image.fromarray(to_numpy(rmask, denormalize=False), mode="L"))
            # Reference with mask applied, and its complement, on white.
            results_list.append(Image.composite(
                reference,
                Image.new("RGB", reference.size, (255, 255, 255)),
                Image.fromarray(to_numpy(rmask, denormalize=False), mode="L")
            ))
            results_list.append(Image.composite(
                Image.new("RGB", reference.size, (255, 255, 255)),
                reference,
                Image.fromarray(to_numpy(rmask, denormalize=False), mode="L")
            ))

    return results_list
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def parse_prompts(
    prompts: str,
    target: str = None,
    anchor: str = None,
    control: str = None,
    target_scale: float = None,
    ts0: float = None,
    ts1: float = None,
    ts2: float = None,
    ts3: float = None,
    enhance: bool = None
):
    """Parse serialized text-manipulation prompts into parallel option lists.

    Each line of *prompts* has the shape:
        "[target] <t>; [anchor] <a>; [control] <c>; [scale]<f>; [enhanced]<b>; [ts0]<f>; [ts1]<f>; [ts2]<f>; [ts3]<f>"
    The keyword arguments describe one extra, not-yet-serialized entry that is
    appended after the parsed lines (skipped when *target* is empty/None).

    Returns:
        Dict of parallel lists: targets, anchors, controls, target_scales,
        enhances, thresholds_list (each thresholds entry is [ts0..ts3]).
    """
    targets = []
    anchors = []
    controls = []
    scales = []
    enhances = []
    thresholds_list = []

    # Field markers; each gets replaced by a common separator before splitting.
    markers = ["; [anchor] ", "; [control] ", "; [scale]", "; [enhanced]",
               "; [ts0]", "; [ts1]", "; [ts2]", "; [ts3]"]
    if prompts:
        for line in prompts.split('\n'):
            line = line.replace("[target] ", "")
            # Renamed loop variable: the original shadowed the builtin `str`.
            for marker in markers:
                line = line.replace(marker, "||||")

            fields = line.split("||||")
            targets.append(fields[0])
            anchors.append(fields[1])
            controls.append(fields[2])
            scales.append(float(fields[3]))
            # bool("False") is True, so parse the serialized flag explicitly
            # instead of casting the raw string.
            enhances.append(fields[4].strip().lower() in ("true", "1"))
            thresholds_list.append([float(fields[5]), float(fields[6]),
                                    float(fields[7]), float(fields[8])])

    if target is not None and target != "":
        targets.append(target)
        anchors.append(anchor)
        controls.append(control)
        scales.append(target_scale)
        enhances.append(enhance)
        thresholds_list.append([ts0, ts1, ts2, ts3])

    return {
        "targets": targets,
        "anchors": anchors,
        "controls": controls,
        "target_scales": scales,
        "enhances": enhances,
        "thresholds_list": thresholds_list
    }
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
from refnet.sampling.manipulation import get_heatmaps
def visualize_heatmaps(model, reference, manipulation_params, control, ts0, ts1, ts2, ts3):
    """Overlay the model's manipulation heatmaps on the reference image.

    Downscales the reference to at most maxium_resolution per side, sums the
    four threshold heatmaps, blends them onto the reference with a JET
    colormap, and draws a 16x16 token grid.
    Returns a single-element list with the numpy overlay (empty if no reference).
    """
    if reference is None:
        return []

    size = reference.size
    if size[0] > maxium_resolution or size[1] > maxium_resolution:
        # Shrink the longer side to maxium_resolution, keeping aspect ratio.
        if size[0] > size[1]:
            size = (maxium_resolution, int(float(maxium_resolution) / size[0] * size[1]))
        else:
            size = (int(float(maxium_resolution) / size[1] * size[0]), maxium_resolution)
        reference = reference.resize(size, Image.BICUBIC)

    reference = np.array(reference)
    scale_maps = get_heatmaps(model, to_tensor(reference), size[1], size[0],
                              control, ts0, ts1, ts2, ts3, **manipulation_params)

    # Combine the four per-threshold maps into a single intensity map.
    scale_map = scale_maps[0] + scale_maps[1] + scale_maps[2] + scale_maps[3]
    heatmap = cv2.cvtColor(cv2.applyColorMap(scale_map, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
    result = cv2.addWeighted(reference, 0.3, heatmap, 0.7, 0)
    # Draw black grid lines every token cell (token_length = 16 cells per side).
    hu = size[1] // token_length
    wu = size[0] // token_length
    for i in range(16):
        result[i * hu, :] = (0, 0, 0)
    for i in range(16):
        result[:, i * wu] = (0, 0, 0)

    return [result]
|
backend/style.css
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
:root {
|
| 2 |
+
--primary-color: #9b59b6;
|
| 3 |
+
--primary-light: #d6c6e1;
|
| 4 |
+
--secondary-color: #2ecc71;
|
| 5 |
+
--text-color: #333333;
|
| 6 |
+
--background-color: #f9f9f9;
|
| 7 |
+
--card-bg: #ffffff;
|
| 8 |
+
--border-radius: 10px;
|
| 9 |
+
--shadow-sm: 0 2px 5px rgba(0, 0, 0, 0.05);
|
| 10 |
+
--shadow-md: 0 5px 15px rgba(0, 0, 0, 0.07);
|
| 11 |
+
--shadow-lg: 0 10px 25px rgba(0, 0, 0, 0.1);
|
| 12 |
+
--gradient: linear-gradient(135deg, var(--primary-color), var(--secondary-color));
|
| 13 |
+
--input-border: #e0e0e0;
|
| 14 |
+
--input-bg: #ffffff;
|
| 15 |
+
--font-weight-normal: 500;
|
| 16 |
+
--font-weight-bold: 700;
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
/* Base styles */
|
| 20 |
+
body, html {
|
| 21 |
+
margin: 0;
|
| 22 |
+
padding: 0;
|
| 23 |
+
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
|
| 24 |
+
font-weight: var(--font-weight-normal);
|
| 25 |
+
background-color: var(--background-color);
|
| 26 |
+
color: var(--text-color);
|
| 27 |
+
width: 100vw;
|
| 28 |
+
overflow-x: hidden;
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
* {
|
| 32 |
+
box-sizing: border-box;
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
/* Force full width layout */
|
| 36 |
+
#main-interface,
|
| 37 |
+
.gradio-app,
|
| 38 |
+
.gradio-container {
|
| 39 |
+
width: 100vw !important;
|
| 40 |
+
max-width: 100vw !important;
|
| 41 |
+
margin: 0 !important;
|
| 42 |
+
padding: 0 !important;
|
| 43 |
+
box-shadow: none !important;
|
| 44 |
+
border: none !important;
|
| 45 |
+
overflow-x: hidden !important;
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
/* Header styling */
|
| 49 |
+
#header-row {
|
| 50 |
+
background: white;
|
| 51 |
+
padding: 15px 20px;
|
| 52 |
+
margin-bottom: 20px;
|
| 53 |
+
box-shadow: var(--shadow-sm);
|
| 54 |
+
border-bottom: 1px solid rgba(0,0,0,0.05);
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
.header-container {
|
| 58 |
+
width: 100%;
|
| 59 |
+
display: flex;
|
| 60 |
+
flex-direction: column;
|
| 61 |
+
align-items: center;
|
| 62 |
+
padding: 10px 0;
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
.app-header {
|
| 66 |
+
display: flex;
|
| 67 |
+
align-items: center;
|
| 68 |
+
gap: 12px;
|
| 69 |
+
margin-bottom: 15px;
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
.app-header .emoji {
|
| 73 |
+
font-size: 36px;
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
/* Fix for Colorize Diffusion title visibility */
|
| 77 |
+
.gradio-markdown h1,
|
| 78 |
+
.gradio-markdown h2,
|
| 79 |
+
#header-row h1,
|
| 80 |
+
#header-row h2,
|
| 81 |
+
.title-text,
|
| 82 |
+
.app-header .title-text {
|
| 83 |
+
display: inline-block !important;
|
| 84 |
+
visibility: visible !important;
|
| 85 |
+
opacity: 1 !important;
|
| 86 |
+
position: relative !important;
|
| 87 |
+
color: var(--primary-color) !important;
|
| 88 |
+
font-size: 32px !important;
|
| 89 |
+
font-weight: 800 !important;
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
/* Badge links under the header */
|
| 93 |
+
.paper-links-icons {
|
| 94 |
+
display: flex;
|
| 95 |
+
flex-wrap: wrap;
|
| 96 |
+
justify-content: center;
|
| 97 |
+
gap: 8px;
|
| 98 |
+
margin-top: 5px;
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
.paper-links-icons a {
|
| 102 |
+
transition: transform 0.2s ease;
|
| 103 |
+
opacity: 0.9;
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
.paper-links-icons a:hover {
|
| 107 |
+
transform: translateY(-3px);
|
| 108 |
+
opacity: 1;
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
/* Content layout */
|
| 112 |
+
#content-row {
|
| 113 |
+
padding: 0 20px 20px 20px;
|
| 114 |
+
max-width: 100%;
|
| 115 |
+
margin: 0 auto;
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
/* Apply bold font to all text elements for better readability */
|
| 119 |
+
p, span, label, button, input, textarea, select, .gradio-button, .gradio-checkbox, .gradio-dropdown, .gradio-textbox {
|
| 120 |
+
font-weight: var(--font-weight-normal);
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
/* Make headings bolder */
|
| 124 |
+
h1, h2, h3, h4, h5, h6 {
|
| 125 |
+
font-weight: var(--font-weight-bold);
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
/* Improved font styling for Gradio UI elements */
|
| 129 |
+
.gradio-container,
|
| 130 |
+
.gradio-container *,
|
| 131 |
+
.gradio-app,
|
| 132 |
+
.gradio-app * {
|
| 133 |
+
font-family: 'Roboto', 'Segoe UI', system-ui, -apple-system, sans-serif !important;
|
| 134 |
+
font-weight: 500 !important;
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
/* Style for labels and slider labels */
|
| 138 |
+
.gradio-container label,
|
| 139 |
+
.gradio-slider label,
|
| 140 |
+
.gradio-checkbox label,
|
| 141 |
+
.gradio-radio label,
|
| 142 |
+
.gradio-dropdown label,
|
| 143 |
+
.gradio-textbox label,
|
| 144 |
+
.gradio-number label,
|
| 145 |
+
.gradio-button,
|
| 146 |
+
.gradio-checkbox span,
|
| 147 |
+
.gradio-radio span {
|
| 148 |
+
font-family: 'Roboto', 'Segoe UI', system-ui, -apple-system, sans-serif !important;
|
| 149 |
+
font-weight: 600 !important;
|
| 150 |
+
letter-spacing: 0.01em;
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
/* Style for buttons */
|
| 154 |
+
button,
|
| 155 |
+
.gradio-button {
|
| 156 |
+
font-family: 'Roboto', 'Segoe UI', system-ui, -apple-system, sans-serif !important;
|
| 157 |
+
font-weight: 600 !important;
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
/* Style for input values */
|
| 161 |
+
input,
|
| 162 |
+
textarea,
|
| 163 |
+
select,
|
| 164 |
+
.gradio-textbox textarea,
|
| 165 |
+
.gradio-number input {
|
| 166 |
+
font-family: 'Roboto', 'Segoe UI', system-ui, -apple-system, sans-serif !important;
|
| 167 |
+
font-weight: 500 !important;
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
/* Better styling for drop areas */
|
| 171 |
+
.upload-box,
|
| 172 |
+
[data-testid="image"] {
|
| 173 |
+
font-family: 'Roboto', 'Segoe UI', system-ui, -apple-system, sans-serif !important;
|
| 174 |
+
font-weight: 500 !important;
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
/* Additional styling for values in sliders and numbers */
|
| 178 |
+
.wrap .wrap .wrap span {
|
| 179 |
+
font-family: 'Roboto', 'Segoe UI', system-ui, -apple-system, sans-serif !important;
|
| 180 |
+
font-weight: 600 !important;
|
| 181 |
+
}
|
configs/inference/sdxl.yaml
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
base_learning_rate: 1.0e-6
|
| 3 |
+
target: refnet.models.colorizerXL.InferenceWrapper
|
| 4 |
+
params:
|
| 5 |
+
linear_start: 0.00085
|
| 6 |
+
linear_end: 0.0120
|
| 7 |
+
timesteps: 1000
|
| 8 |
+
image_size: 128
|
| 9 |
+
channels: 4
|
| 10 |
+
scale_factor: 0.13025
|
| 11 |
+
logits_embed: false
|
| 12 |
+
|
| 13 |
+
unet_config:
|
| 14 |
+
target: refnet.modules.unet.DualCondUNetXL
|
| 15 |
+
params:
|
| 16 |
+
use_checkpoint: True
|
| 17 |
+
in_channels: 4
|
| 18 |
+
out_channels: 4
|
| 19 |
+
model_channels: 320
|
| 20 |
+
adm_in_channels: 512
|
| 21 |
+
# adm_in_channels: 2816
|
| 22 |
+
num_classes: sequential
|
| 23 |
+
attention_resolutions: [4, 2]
|
| 24 |
+
num_res_blocks: 2
|
| 25 |
+
channel_mult: [1, 2, 4]
|
| 26 |
+
num_head_channels: 64
|
| 27 |
+
use_spatial_transformer: true
|
| 28 |
+
use_linear_in_transformer: true
|
| 29 |
+
transformer_depth: [1, 2, 10]
|
| 30 |
+
context_dim: 2048
|
| 31 |
+
|
| 32 |
+
first_stage_config:
|
| 33 |
+
target: ldm.models.autoencoder.AutoencoderKL
|
| 34 |
+
params:
|
| 35 |
+
embed_dim: 4
|
| 36 |
+
ddconfig:
|
| 37 |
+
double_z: true
|
| 38 |
+
z_channels: 4
|
| 39 |
+
resolution: 512
|
| 40 |
+
in_channels: 3
|
| 41 |
+
out_ch: 3
|
| 42 |
+
ch: 128
|
| 43 |
+
ch_mult: [1, 2, 4, 4]
|
| 44 |
+
num_res_blocks: 2
|
| 45 |
+
attn_resolutions: []
|
| 46 |
+
dropout: 0.0
|
| 47 |
+
|
| 48 |
+
cond_stage_config:
|
| 49 |
+
target: refnet.modules.embedder.HFCLIPVisionModel
|
| 50 |
+
# target: refnet.modules.embedder.FrozenOpenCLIPImageEmbedder
|
| 51 |
+
params:
|
| 52 |
+
arch: ViT-bigG-14
|
| 53 |
+
|
| 54 |
+
control_encoder_config:
|
| 55 |
+
# target: refnet.modules.encoder.MultiEncoder
|
| 56 |
+
target: refnet.modules.encoder.MultiScaleAttentionEncoder
|
| 57 |
+
params:
|
| 58 |
+
in_ch: 3
|
| 59 |
+
model_channels: 320
|
| 60 |
+
ch_mults: [ 1, 2, 4 ]
|
| 61 |
+
|
| 62 |
+
img_embedder_config:
|
| 63 |
+
target: refnet.modules.embedder.WDv14SwinTransformerV2
|
| 64 |
+
|
| 65 |
+
scalar_embedder_config:
|
| 66 |
+
target: refnet.modules.embedder.TimestepEmbedding
|
| 67 |
+
params:
|
| 68 |
+
embed_dim: 256
|
| 69 |
+
|
| 70 |
+
proj_config:
|
| 71 |
+
target: refnet.modules.proj.ClusterConcat
|
| 72 |
+
# target: refnet.modules.proj.RecoveryClusterConcat
|
| 73 |
+
params:
|
| 74 |
+
input_dim: 1280
|
| 75 |
+
c_dim: 1024
|
| 76 |
+
output_dim: 2048
|
| 77 |
+
token_length: 196
|
| 78 |
+
dim_head: 128
|
| 79 |
+
# proj_config:
|
| 80 |
+
# target: refnet.modules.proj.LogitClusterConcat
|
| 81 |
+
# params:
|
| 82 |
+
# input_dim: 1280
|
| 83 |
+
# c_dim: 1024
|
| 84 |
+
# output_dim: 2048
|
| 85 |
+
# token_length: 196
|
| 86 |
+
# dim_head: 128
|
| 87 |
+
# mlp_in_dim: 9083
|
| 88 |
+
# mlp_ckpt_path: pretrained_models/proj.safetensors
|
configs/inference/xlv2.yaml
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
base_learning_rate: 1.0e-6
|
| 3 |
+
target: refnet.models.v2-colorizerXL.InferenceWrapperXL
|
| 4 |
+
params:
|
| 5 |
+
linear_start: 0.00085
|
| 6 |
+
linear_end: 0.0120
|
| 7 |
+
timesteps: 1000
|
| 8 |
+
image_size: 128
|
| 9 |
+
channels: 4
|
| 10 |
+
scale_factor: 0.13025
|
| 11 |
+
controller: true
|
| 12 |
+
|
| 13 |
+
unet_config:
|
| 14 |
+
target: refnet.modules.unet.DualCondUNetXL
|
| 15 |
+
params:
|
| 16 |
+
use_checkpoint: True
|
| 17 |
+
in_channels: 4
|
| 18 |
+
in_channels_fg: 4
|
| 19 |
+
out_channels: 4
|
| 20 |
+
model_channels: 320
|
| 21 |
+
adm_in_channels: 512
|
| 22 |
+
num_classes: sequential
|
| 23 |
+
attention_resolutions: [4, 2]
|
| 24 |
+
num_res_blocks: 2
|
| 25 |
+
channel_mult: [1, 2, 4]
|
| 26 |
+
num_head_channels: 64
|
| 27 |
+
use_spatial_transformer: true
|
| 28 |
+
use_linear_in_transformer: true
|
| 29 |
+
transformer_depth: [1, 2, 10]
|
| 30 |
+
context_dim: 2048
|
| 31 |
+
map_module: false
|
| 32 |
+
warp_module: false
|
| 33 |
+
style_modulation: false
|
| 34 |
+
|
| 35 |
+
bg_encoder_config:
|
| 36 |
+
target: refnet.modules.unet.ReferenceNet
|
| 37 |
+
params:
|
| 38 |
+
use_checkpoint: True
|
| 39 |
+
in_channels: 6
|
| 40 |
+
model_channels: 320
|
| 41 |
+
adm_in_channels: 1024
|
| 42 |
+
num_classes: sequential
|
| 43 |
+
attention_resolutions: [ 4, 2 ]
|
| 44 |
+
num_res_blocks: 2
|
| 45 |
+
channel_mult: [ 1, 2, 4 ]
|
| 46 |
+
num_head_channels: 64
|
| 47 |
+
use_spatial_transformer: true
|
| 48 |
+
use_linear_in_transformer: true
|
| 49 |
+
disable_cross_attentions: true
|
| 50 |
+
context_dim: 2048
|
| 51 |
+
transformer_depth: [ 1, 2, 10 ]
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
first_stage_config:
|
| 55 |
+
target: ldm.models.autoencoder.AutoencoderKL
|
| 56 |
+
params:
|
| 57 |
+
embed_dim: 4
|
| 58 |
+
ddconfig:
|
| 59 |
+
double_z: true
|
| 60 |
+
z_channels: 4
|
| 61 |
+
resolution: 512
|
| 62 |
+
in_channels: 3
|
| 63 |
+
out_ch: 3
|
| 64 |
+
ch: 128
|
| 65 |
+
ch_mult: [1, 2, 4, 4]
|
| 66 |
+
num_res_blocks: 2
|
| 67 |
+
attn_resolutions: []
|
| 68 |
+
dropout: 0.0
|
| 69 |
+
|
| 70 |
+
cond_stage_config:
|
| 71 |
+
target: refnet.modules.embedder.HFCLIPVisionModel
|
| 72 |
+
params:
|
| 73 |
+
arch: ViT-bigG-14
|
| 74 |
+
|
| 75 |
+
img_embedder_config:
|
| 76 |
+
target: refnet.modules.embedder.WDv14SwinTransformerV2
|
| 77 |
+
|
| 78 |
+
control_encoder_config:
|
| 79 |
+
target: refnet.modules.encoder.MultiScaleAttentionEncoder
|
| 80 |
+
params:
|
| 81 |
+
in_ch: 3
|
| 82 |
+
model_channels: 320
|
| 83 |
+
ch_mults: [1, 2, 4]
|
| 84 |
+
|
| 85 |
+
proj_config:
|
| 86 |
+
target: refnet.modules.proj.ClusterConcat
|
| 87 |
+
# target: refnet.modules.proj.RecoveryClusterConcat
|
| 88 |
+
params:
|
| 89 |
+
input_dim: 1280
|
| 90 |
+
c_dim: 1024
|
| 91 |
+
output_dim: 2048
|
| 92 |
+
token_length: 196
|
| 93 |
+
dim_head: 128
|
| 94 |
+
|
| 95 |
+
scalar_embedder_config:
|
| 96 |
+
target: refnet.modules.embedder.TimestepEmbedding
|
| 97 |
+
params:
|
| 98 |
+
embed_dim: 256
|
| 99 |
+
|
| 100 |
+
lora_config:
|
| 101 |
+
lora_params: [
|
| 102 |
+
{
|
| 103 |
+
label: background,
|
| 104 |
+
root_module: model.diffusion_model,
|
| 105 |
+
target_keys: [ attn2.to_q, attn2.to_k, attn2.to_v ],
|
| 106 |
+
r: 4,
|
| 107 |
+
}
|
| 108 |
+
]
|
configs/scheduler_cfgs/ddim.yaml
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
beta_start: 0.00085
|
| 2 |
+
beta_end: 0.012
|
| 3 |
+
beta_schedule: "scaled_linear"
|
| 4 |
+
clip_sample: false
|
| 5 |
+
steps_offset: 1
|
| 6 |
+
|
| 7 |
+
### Zero-SNR params
|
| 8 |
+
#rescale_betas_zero_snr: True
|
| 9 |
+
#timestep_spacing: "trailing"
|
| 10 |
+
timestep_spacing: "leading"
|
configs/scheduler_cfgs/dpm.yaml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
beta_start: 0.00085
|
| 2 |
+
beta_end: 0.012
|
| 3 |
+
beta_schedule: "scaled_linear"
|
| 4 |
+
steps_offset: 1
|
| 5 |
+
|
| 6 |
+
### Zero-SNR params
|
| 7 |
+
#rescale_betas_zero_snr: True
|
| 8 |
+
timestep_spacing: "leading"
|
configs/scheduler_cfgs/dpm_sde.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
beta_start: 0.00085
|
| 2 |
+
beta_end: 0.012
|
| 3 |
+
beta_schedule: "scaled_linear"
|
| 4 |
+
steps_offset: 1
|
| 5 |
+
|
| 6 |
+
### Zero-SNR params
|
| 7 |
+
#rescale_betas_zero_snr: True
|
| 8 |
+
timestep_spacing: "leading"
|
| 9 |
+
algorithm_type: sde-dpmsolver++
|
configs/scheduler_cfgs/lms.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
beta_start: 0.00085
|
| 2 |
+
beta_end: 0.012
|
| 3 |
+
beta_schedule: "scaled_linear"
|
| 4 |
+
#clip_sample: false
|
| 5 |
+
steps_offset: 1
|
| 6 |
+
|
| 7 |
+
### Zero-SNR params
|
| 8 |
+
#rescale_betas_zero_snr: True
|
| 9 |
+
timestep_spacing: "leading"
|
configs/scheduler_cfgs/pndm.yaml
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
beta_start: 0.00085
|
| 2 |
+
beta_end: 0.012
|
| 3 |
+
beta_schedule: "scaled_linear"
|
| 4 |
+
#clip_sample: false
|
| 5 |
+
steps_offset: 1
|
| 6 |
+
|
| 7 |
+
### Zero-SNR params
|
| 8 |
+
#rescale_betas_zero_snr: True
|
| 9 |
+
#timestep_spacing: "trailing"
|
| 10 |
+
timestep_spacing: "leading"
|
k_diffusion/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .sampling import *
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def create_noise_sampler(x, sigmas, seed):
    """Build a BrownianTreeNoiseSampler for DPM++ SDE so results stay
    deterministic across different batch sizes."""
    from k_diffusion.sampling import BrownianTreeNoiseSampler
    nonzero = sigmas[sigmas > 0]
    return BrownianTreeNoiseSampler(x, nonzero.min(), sigmas.max(), seed=seed)
|
k_diffusion/external.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
|
| 6 |
+
from . import sampling, utils
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class VDenoiser(nn.Module):
    """A v-diffusion-pytorch model wrapper for k-diffusion."""

    def __init__(self, inner_model):
        super().__init__()
        self.inner_model = inner_model
        # Assumed data standard deviation for the preconditioning scalings.
        self.sigma_data = 1.

    def get_scalings(self, sigma):
        """Return (c_skip, c_out, c_in) preconditioning factors for v-parameterization."""
        c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
        c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
        c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
        return c_skip, c_out, c_in

    def sigma_to_t(self, sigma):
        # Map sigma in (0, inf) to t in (0, 1) via arctan.
        return sigma.atan() / math.pi * 2

    def t_to_sigma(self, t):
        # Inverse of sigma_to_t.
        return (t * math.pi / 2).tan()

    def loss(self, input, noise, sigma, **kwargs):
        """Per-sample MSE loss of the model's v-prediction at noise level *sigma*."""
        c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
        noised_input = input + noise * utils.append_dims(sigma, input.ndim)
        model_output = self.inner_model(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
        # The v-target expressed in terms of the clean input and noised input.
        target = (input - c_skip * noised_input) / c_out
        return (model_output - target).pow(2).flatten(1).mean(1)

    def forward(self, input, sigma, **kwargs):
        """Denoise *input* at noise level *sigma*; returns the predicted clean sample."""
        c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
        return self.inner_model(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class DiscreteSchedule(nn.Module):
    """A mapping between continuous noise levels (sigmas) and a list of discrete noise
    levels."""

    def __init__(self, sigmas, quantize):
        super().__init__()
        # sigmas are stored ascending; log-sigmas cached for interpolation.
        self.register_buffer('sigmas', sigmas)
        self.register_buffer('log_sigmas', sigmas.log())
        # If True, sigma_to_t snaps to the nearest discrete timestep index.
        self.quantize = quantize

    @property
    def sigma_min(self):
        return self.sigmas[0]

    @property
    def sigma_max(self):
        return self.sigmas[-1]

    def get_sigmas(self, n=None):
        """Return n descending sigmas (full schedule if n is None) with a trailing zero."""
        if n is None:
            return sampling.append_zero(self.sigmas.flip(0))
        t_max = len(self.sigmas) - 1
        t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
        return sampling.append_zero(self.t_to_sigma(t))

    def sigma_to_t(self, sigma, quantize=None):
        """Map continuous sigma to a (possibly fractional) discrete timestep index."""
        quantize = self.quantize if quantize is None else quantize
        log_sigma = sigma.log()
        dists = log_sigma - self.log_sigmas[:, None]
        if quantize:
            return dists.abs().argmin(dim=0).view(sigma.shape)
        # Find the bracketing schedule indices, then linearly interpolate in
        # log-sigma space between them.
        low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1
        low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx]
        w = (low - log_sigma) / (low - high)
        w = w.clamp(0, 1)
        t = (1 - w) * low_idx + w * high_idx
        return t.view(sigma.shape)

    def t_to_sigma(self, t):
        """Inverse of sigma_to_t: interpolate log-sigmas at fractional index *t*."""
        t = t.float()
        low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
        log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
        return log_sigma.exp()
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
    """A wrapper for discrete schedule DDPM models that output eps (the
    predicted noise)."""

    def __init__(self, model, alphas_cumprod, quantize):
        # sigma_t = sqrt((1 - abar_t) / abar_t) for a variance-preserving schedule.
        sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
        super().__init__(sigmas, quantize)
        self.inner_model = model
        self.sigma_data = 1.

    def get_scalings(self, sigma):
        """Return (c_out, c_in) preconditioning factors for an eps model."""
        c_out = -sigma
        c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
        return c_out, c_in

    def get_eps(self, *args, **kwargs):
        """Hook for subclasses; by default call the wrapped model directly."""
        return self.inner_model(*args, **kwargs)

    def loss(self, input, noise, sigma, **kwargs):
        """Per-sample MSE between predicted and actual noise."""
        c_out, c_in = (utils.append_dims(s, input.ndim) for s in self.get_scalings(sigma))
        noised_input = input + noise * utils.append_dims(sigma, input.ndim)
        eps = self.get_eps(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
        return (eps - noise).pow(2).flatten(1).mean(1)

    def forward(self, input, sigma, **kwargs):
        """Denoise *input* by predicting the noise and subtracting it (scaled)."""
        c_out, c_in = (utils.append_dims(s, input.ndim) for s in self.get_scalings(sigma))
        eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
        return input + eps * c_out
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class OpenAIDenoiser(DiscreteEpsDDPMDenoiser):
    """A wrapper for OpenAI diffusion models."""

    def __init__(self, model, diffusion, quantize=False, has_learned_sigmas=True, device='cpu'):
        alphas_cumprod = torch.tensor(diffusion.alphas_cumprod, device=device, dtype=torch.float32)
        super().__init__(model, alphas_cumprod, quantize=quantize)
        self.has_learned_sigmas = has_learned_sigmas

    def get_eps(self, *args, **kwargs):
        model_output = self.inner_model(*args, **kwargs)
        if not self.has_learned_sigmas:
            return model_output
        # Output is (eps, learned variance) stacked along dim 1; keep eps only.
        return model_output.chunk(2, dim=1)[0]
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class CompVisDenoiser(DiscreteEpsDDPMDenoiser):
    """A wrapper for CompVis diffusion models."""

    def __init__(self, model, quantize=False, device='cpu'):
        super().__init__(model, model.alphas_cumprod, quantize=quantize)
        # Buffers were created on the alphas' device; move them where requested.
        self.sigmas = self.sigmas.to(device)
        self.log_sigmas = self.log_sigmas.to(device)

    def get_eps(self, *args, **kwargs):
        # CompVis models expose their UNet through apply_model rather than __call__.
        return self.inner_model.apply_model(*args, **kwargs)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class DiscreteVDDPMDenoiser(DiscreteSchedule):
    """A wrapper for discrete schedule DDPM models that output v."""

    def __init__(self, model, alphas_cumprod, quantize):
        # sigma_t = sqrt((1 - abar_t) / abar_t) for a variance-preserving schedule.
        sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
        super().__init__(sigmas, quantize)
        self.inner_model = model
        self.sigma_data = 1.

    def get_scalings(self, sigma):
        """Return (c_skip, c_out, c_in) preconditioning factors for a v model."""
        variance = sigma ** 2 + self.sigma_data ** 2
        c_skip = self.sigma_data ** 2 / variance
        c_out = -sigma * self.sigma_data / variance ** 0.5
        c_in = 1 / variance ** 0.5
        return c_skip, c_out, c_in

    def get_v(self, *args, **kwargs):
        """Hook for subclasses; by default call the wrapped model directly."""
        return self.inner_model(*args, **kwargs)

    def loss(self, input, noise, sigma, **kwargs):
        """Per-sample MSE against the v-prediction target."""
        c_skip, c_out, c_in = (utils.append_dims(s, input.ndim) for s in self.get_scalings(sigma))
        noised_input = input + noise * utils.append_dims(sigma, input.ndim)
        model_output = self.get_v(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
        target = (input - c_skip * noised_input) / c_out
        return (model_output - target).pow(2).flatten(1).mean(1)

    def forward(self, input, sigma, **kwargs):
        """Denoise *input* at noise level *sigma*."""
        c_skip, c_out, c_in = (utils.append_dims(s, input.ndim) for s in self.get_scalings(sigma))
        return self.get_v(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
class CompVisVDenoiser(DiscreteVDDPMDenoiser):
    """A wrapper for CompVis diffusion models that output v."""

    def __init__(self, model, quantize=False, device='cpu'):
        super().__init__(model, model.alphas_cumprod, quantize=quantize)
        # Buffers were created on the alphas' device; move them where requested.
        self.sigmas = self.sigmas.to(device)
        self.log_sigmas = self.log_sigmas.to(device)

    def get_v(self, x, t, cond, **kwargs):
        # NOTE(review): extra **kwargs are accepted but not forwarded to
        # apply_model — confirm callers never rely on them reaching the model.
        return self.inner_model.apply_model(x, t, cond)
|
k_diffusion/sampling.py
ADDED
|
@@ -0,0 +1,702 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
from scipy import integrate
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
from torchdiffeq import odeint
|
| 7 |
+
import torchsde
|
| 8 |
+
from tqdm.auto import trange, tqdm
|
| 9 |
+
|
| 10 |
+
from . import utils
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def append_zero(x):
    """Return *x* with a single zero appended along dim 0."""
    zero = x.new_zeros([1])
    return torch.cat((x, zero))
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
    """Constructs the noise schedule of Karras et al. (2022).

    Interpolates linearly in sigma**(1/rho) space from sigma_max down to
    sigma_min over *n* steps, then appends a trailing zero.
    """
    frac = torch.linspace(0, 1, n).to(device)
    hi = sigma_max ** (1 / rho)
    lo = sigma_min ** (1 / rho)
    sigmas = (hi + frac * (lo - hi)) ** rho
    return append_zero(sigmas).to(device)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'):
    """Constructs an exponential noise schedule (linear in log sigma)."""
    log_sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device)
    return append_zero(log_sigmas.exp())
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'):
    """Constructs a polynomial-in-log-sigma noise schedule."""
    ramp = torch.linspace(1, 0, n, device=device) ** rho
    log_span = math.log(sigma_max) - math.log(sigma_min)
    sigmas = torch.exp(ramp * log_span + math.log(sigma_min))
    return append_zero(sigmas)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
    """Constructs a continuous VP noise schedule."""
    timesteps = torch.linspace(1, eps_s, n, device=device)
    log_mean_coeff = beta_d * timesteps ** 2 / 2 + beta_min * timesteps
    sigmas = torch.sqrt(torch.exp(log_mean_coeff) - 1)
    return append_zero(sigmas)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def to_d(x, sigma, denoised):
    """Converts a denoiser output to a Karras ODE derivative."""
    sigma_expanded = utils.append_dims(sigma, x.ndim)
    return (x - denoised) / sigma_expanded
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def get_ancestral_step(sigma_from, sigma_to, eta=1.):
    """Calculates the noise level (sigma_down) to step down to and the amount
    of noise to add (sigma_up) when doing an ancestral sampling step."""
    if not eta:
        # Deterministic step: go straight to sigma_to, add no noise.
        return sigma_to, 0.
    max_noise = eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5
    sigma_up = min(sigma_to, max_noise)
    sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
    return sigma_down, sigma_up
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def default_noise_sampler(x):
    """Return a noise sampler that ignores the sigma interval and draws plain
    Gaussian noise shaped like *x*."""
    def sampler(sigma, sigma_next):
        return torch.randn_like(x)
    return sampler
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class BatchedBrownianTree:
    """A wrapper around torchsde.BrownianTree that enables batches of entropy."""

    def __init__(self, x, t0, t1, seed=None, **kwargs):
        t0, t1, self.sign = self.sort(t0, t1)
        w0 = kwargs.get('w0', torch.zeros_like(x))
        if seed is None:
            seed = torch.randint(0, 2 ** 63 - 1, []).item()
        self.batched = True
        try:
            # A sequence of seeds means one tree per batch item.
            assert len(seed) == x.shape[0]
            w0 = w0[0]
        except TypeError:
            # A single int seed: one shared tree for the whole batch.
            seed = [seed]
            self.batched = False
        self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]

    @staticmethod
    def sort(a, b):
        """Return (min, max, sign) where sign records whether the order flipped."""
        if a < b:
            return a, b, 1
        return b, a, -1

    def __call__(self, t0, t1):
        t0, t1, sign = self.sort(t0, t1)
        increments = [tree(t0, t1) for tree in self.trees]
        # Undo both the construction-time and call-time reorderings.
        w = torch.stack(increments) * (self.sign * sign)
        return w if self.batched else w[0]
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class BrownianTreeNoiseSampler:
    """A noise sampler backed by a torchsde.BrownianTree.

    Args:
        x (Tensor): The tensor whose shape, device and dtype to use to generate
            random samples.
        sigma_min (float): The low end of the valid interval.
        sigma_max (float): The high end of the valid interval.
        seed (int or List[int]): The random seed. If a list of seeds is
            supplied instead of a single integer, then the noise sampler will
            use one BrownianTree per batch item, each with its own seed.
        transform (callable): A function that maps sigma to the sampler's
            internal timestep.
    """

    def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
        self.transform = transform
        t0 = self.transform(torch.as_tensor(sigma_min))
        t1 = self.transform(torch.as_tensor(sigma_max))
        self.tree = BatchedBrownianTree(x, t0, t1, seed)

    def __call__(self, sigma, sigma_next):
        t0 = self.transform(torch.as_tensor(sigma))
        t1 = self.transform(torch.as_tensor(sigma_next))
        # Normalize the Brownian increment to unit variance.
        return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
@torch.no_grad()
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
    if extra_args is None:
        extra_args = {}
    s_in = x.new_ones([x.shape[0]])
    n_steps = len(sigmas) - 1
    for i in trange(n_steps, disable=disable):
        in_churn_range = s_tmin <= sigmas[i] <= s_tmax
        gamma = min(s_churn / n_steps, 2 ** 0.5 - 1) if in_churn_range else 0.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            # "Churn": temporarily raise the noise level before stepping.
            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        dt = sigmas[i + 1] - sigma_hat
        # Euler step.
        x = x + d * dt
    return x
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
@torch.no_grad()
def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with Euler method steps."""
    if extra_args is None:
        extra_args = {}
    if noise_sampler is None:
        noise_sampler = default_noise_sampler(x)
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        d = to_d(x, sigmas[i], denoised)
        # Deterministic Euler step down to sigma_down ...
        x = x + d * (sigma_down - sigmas[i])
        # ... then re-inject noise up to sigmas[i + 1], except on the final step.
        if sigmas[i + 1] > 0:
            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
    return x
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
@torch.no_grad()
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
    if extra_args is None:
        extra_args = {}
    s_in = x.new_ones([x.shape[0]])
    n_steps = len(sigmas) - 1
    for i in trange(n_steps, disable=disable):
        in_churn_range = s_tmin <= sigmas[i] <= s_tmax
        gamma = min(s_churn / n_steps, 2 ** 0.5 - 1) if in_churn_range else 0.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            # "Churn": temporarily raise the noise level before stepping.
            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        dt = sigmas[i + 1] - sigma_hat
        if sigmas[i + 1] == 0:
            # Final step to sigma = 0 must be plain Euler.
            x = x + d * dt
        else:
            # Heun's method: average the derivative at both ends of the step.
            x_2 = x + d * dt
            denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
            d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
            x = x + (d + d_2) / 2 * dt
    return x
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
@torch.no_grad()
def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
    if extra_args is None:
        extra_args = {}
    s_in = x.new_ones([x.shape[0]])
    n_steps = len(sigmas) - 1
    for i in trange(n_steps, disable=disable):
        in_churn_range = s_tmin <= sigmas[i] <= s_tmax
        gamma = min(s_churn / n_steps, 2 ** 0.5 - 1) if in_churn_range else 0.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            # "Churn": temporarily raise the noise level before stepping.
            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Final step to sigma = 0: plain Euler.
            x = x + d * (sigmas[i + 1] - sigma_hat)
        else:
            # DPM-Solver-2: second evaluation at the log-space midpoint.
            sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
            dt_1 = sigma_mid - sigma_hat
            dt_2 = sigmas[i + 1] - sigma_hat
            x_2 = x + d * dt_1
            denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
            d_2 = to_d(x_2, sigma_mid, denoised_2)
            x = x + d_2 * dt_2
    return x
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
@torch.no_grad()
def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with DPM-Solver second-order steps."""
    if extra_args is None:
        extra_args = {}
    if noise_sampler is None:
        noise_sampler = default_noise_sampler(x)
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        d = to_d(x, sigmas[i], denoised)
        if sigma_down == 0:
            # Final step: plain Euler down to zero.
            x = x + d * (sigma_down - sigmas[i])
        else:
            # DPM-Solver-2: second evaluation at the log-space midpoint.
            sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp()
            dt_1 = sigma_mid - sigmas[i]
            dt_2 = sigma_down - sigmas[i]
            x_2 = x + d * dt_1
            denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
            d_2 = to_d(x_2, sigma_mid, denoised_2)
            x = x + d_2 * dt_2
        x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
    return x
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def linear_multistep_coeff(order, t, i, j):
    """Integrate the j-th Lagrange basis polynomial over [t[i], t[i + 1]].

    Used as the weight of the j-th most recent derivative in a linear
    multistep (Adams-Bashforth style) update.
    """
    if order - 1 > i:
        raise ValueError(f'Order {order} too high for step {i}')

    def basis(tau):
        prod = 1.
        for k in range(order):
            if j == k:
                continue
            prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
        return prod

    return integrate.quad(basis, t[i], t[i + 1], epsrel=1e-4)[0]
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
@torch.no_grad()
def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
    """Linear multistep sampler: reuses up to *order* past derivatives per step."""
    if extra_args is None:
        extra_args = {}
    s_in = x.new_ones([x.shape[0]])
    # quad() needs plain floats, so keep a CPU copy of the schedule.
    sigmas_cpu = sigmas.detach().cpu().numpy()
    ds = []
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        ds.append(to_d(x, sigmas[i], denoised))
        if len(ds) > order:
            ds.pop(0)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        cur_order = min(i + 1, order)
        coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
        x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
    return x
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
@torch.no_grad()
def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4):
    """Estimate log p(x) by integrating the probability flow ODE from sigma_min
    to sigma_max, using the Skilling-Hutchinson trace estimator for the
    divergence term. Returns (per-sample log-likelihood, info dict)."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    # Rademacher (+/-1) probe vector for the Hutchinson divergence estimate.
    v = torch.randint_like(x, 2) * 2 - 1
    fevals = 0
    def ode_fn(sigma, x):
        nonlocal fevals
        # Gradients are needed inside the ODE function despite the outer no_grad.
        with torch.enable_grad():
            x = x[0].detach().requires_grad_()
            denoised = model(x, sigma * s_in, **extra_args)
            d = to_d(x, sigma, denoised)
            fevals += 1
            grad = torch.autograd.grad((d * v).sum(), x)[0]
            # v^T (dd/dx) v estimates the divergence of the ODE drift.
            d_ll = (v * grad).flatten(1).sum(1)
        return d.detach(), d_ll
    # Integrate the state together with the accumulated log-likelihood change.
    x_min = x, x.new_zeros([x.shape[0]])
    t = x.new_tensor([sigma_min, sigma_max])
    sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5')
    latent, delta_ll = sol[0][-1], sol[1][-1]
    # Prior term: the latent at sigma_max is approximately N(0, sigma_max^2).
    ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1)
    return ll_prior + delta_ll, {'fevals': fevals}
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
class PIDStepSizeController:
    """A PID controller for ODE adaptive step size control."""

    def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8):
        self.h = h
        # PID gains expressed as exponents on the last three inverse errors.
        self.b1 = (pcoeff + icoeff + dcoeff) / order
        self.b2 = -(pcoeff + 2 * dcoeff) / order
        self.b3 = dcoeff / order
        self.accept_safety = accept_safety
        self.eps = eps
        self.errs = []

    def limiter(self, x):
        """Softly clamp the step-size factor around 1."""
        return 1 + math.atan(x - 1)

    def propose_step(self, error):
        """Update the step size from *error*; return True if the step is accepted."""
        inv_error = 1 / (float(error) + self.eps)
        if not self.errs:
            # Seed the history on the first call.
            self.errs = [inv_error, inv_error, inv_error]
        self.errs[0] = inv_error
        raw = self.errs[0] ** self.b1 * self.errs[1] ** self.b2 * self.errs[2] ** self.b3
        factor = self.limiter(raw)
        accept = factor >= self.accept_safety
        if accept:
            # Shift the error history only on accepted steps.
            self.errs[2] = self.errs[1]
            self.errs[1] = self.errs[0]
        self.h *= factor
        return accept
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
class DPMSolver(nn.Module):
    """DPM-Solver. See https://arxiv.org/abs/2206.00927."""

    def __init__(self, model, extra_args=None, eps_callback=None, info_callback=None):
        super().__init__()
        self.model = model
        self.extra_args = {} if extra_args is None else extra_args
        self.eps_callback = eps_callback  # invoked once per model evaluation
        self.info_callback = info_callback  # invoked once per solver step

    def t(self, sigma):
        # Solver time is negative log sigma.
        return -sigma.log()

    def sigma(self, t):
        # Inverse of t(): sigma = exp(-t).
        return t.neg().exp()

    def eps(self, eps_cache, key, x, t, *args, **kwargs):
        """Return the model's eps at (x, t), reusing *eps_cache* entries by *key*."""
        if key in eps_cache:
            return eps_cache[key], eps_cache
        sigma = self.sigma(t) * x.new_ones([x.shape[0]])
        eps = (x - self.model(x, sigma, *args, **self.extra_args, **kwargs)) / self.sigma(t)
        if self.eps_callback is not None:
            self.eps_callback()
        return eps, {key: eps, **eps_cache}

    def dpm_solver_1_step(self, x, t, t_next, eps_cache=None):
        """First-order (exponential Euler) update from t to t_next."""
        eps_cache = {} if eps_cache is None else eps_cache
        h = t_next - t
        eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
        x_1 = x - self.sigma(t_next) * h.expm1() * eps
        return x_1, eps_cache

    def dpm_solver_2_step(self, x, t, t_next, r1=1 / 2, eps_cache=None):
        """Second-order update with one intermediate evaluation at t + r1 * h."""
        eps_cache = {} if eps_cache is None else eps_cache
        h = t_next - t
        eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
        s1 = t + r1 * h
        u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
        eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
        x_2 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / (2 * r1) * h.expm1() * (eps_r1 - eps)
        return x_2, eps_cache

    def dpm_solver_3_step(self, x, t, t_next, r1=1 / 3, r2=2 / 3, eps_cache=None):
        """Third-order update with intermediate evaluations at t + r1*h and t + r2*h."""
        eps_cache = {} if eps_cache is None else eps_cache
        h = t_next - t
        eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
        s1 = t + r1 * h
        s2 = t + r2 * h
        u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
        eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
        u2 = x - self.sigma(s2) * (r2 * h).expm1() * eps - self.sigma(s2) * (r2 / r1) * ((r2 * h).expm1() / (r2 * h) - 1) * (eps_r1 - eps)
        eps_r2, eps_cache = self.eps(eps_cache, 'eps_r2', u2, s2)
        x_3 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / r2 * (h.expm1() / h - 1) * (eps_r2 - eps)
        return x_3, eps_cache

    def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None):
        """Fixed-step sampling spending *nfe* model evaluations, mostly as
        third-order steps with lower-order steps to finish."""
        noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
        if not t_end > t_start and eta:
            raise ValueError('eta must be 0 for reverse sampling')

        m = math.floor(nfe / 3) + 1
        ts = torch.linspace(t_start, t_end, m + 1, device=x.device)

        # Allocate the evaluation budget across step orders.
        if nfe % 3 == 0:
            orders = [3] * (m - 2) + [2, 1]
        else:
            orders = [3] * (m - 1) + [nfe % 3]

        for i in range(len(orders)):
            eps_cache = {}
            t, t_next = ts[i], ts[i + 1]
            if eta:
                # Ancestral variant: step a bit further, then re-add noise below.
                sd, su = get_ancestral_step(self.sigma(t), self.sigma(t_next), eta)
                t_next_ = torch.minimum(t_end, self.t(sd))
                su = (self.sigma(t_next) ** 2 - self.sigma(t_next_) ** 2) ** 0.5
            else:
                t_next_, su = t_next, 0.

            eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
            denoised = x - self.sigma(t) * eps
            if self.info_callback is not None:
                self.info_callback({'x': x, 'i': i, 't': ts[i], 't_up': t, 'denoised': denoised})

            if orders[i] == 1:
                x, eps_cache = self.dpm_solver_1_step(x, t, t_next_, eps_cache=eps_cache)
            elif orders[i] == 2:
                x, eps_cache = self.dpm_solver_2_step(x, t, t_next_, eps_cache=eps_cache)
            else:
                x, eps_cache = self.dpm_solver_3_step(x, t, t_next_, eps_cache=eps_cache)

            x = x + su * s_noise * noise_sampler(self.sigma(t), self.sigma(t_next))

        return x

    def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None):
        """Adaptive-step sampling (DPM-Solver-12/23) driven by a PID step-size
        controller; returns (x, info dict with step/NFE counts)."""
        noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
        if order not in {2, 3}:
            raise ValueError('order should be 2 or 3')
        forward = t_end > t_start
        if not forward and eta:
            raise ValueError('eta must be 0 for reverse sampling')
        h_init = abs(h_init) * (1 if forward else -1)
        atol = torch.tensor(atol)
        rtol = torch.tensor(rtol)
        s = t_start
        x_prev = x
        accept = True
        pid = PIDStepSizeController(h_init, pcoeff, icoeff, dcoeff, 1.5 if eta else order, accept_safety)
        info = {'steps': 0, 'nfe': 0, 'n_accept': 0, 'n_reject': 0}

        while s < t_end - 1e-5 if forward else s > t_end + 1e-5:
            eps_cache = {}
            # Clamp the proposed step so it never overshoots t_end.
            t = torch.minimum(t_end, s + pid.h) if forward else torch.maximum(t_end, s + pid.h)
            if eta:
                sd, su = get_ancestral_step(self.sigma(s), self.sigma(t), eta)
                t_ = torch.minimum(t_end, self.t(sd))
                su = (self.sigma(t) ** 2 - self.sigma(t_) ** 2) ** 0.5
            else:
                t_, su = t, 0.

            eps, eps_cache = self.eps(eps_cache, 'eps', x, s)
            denoised = x - self.sigma(s) * eps

            # Error estimate: compare a lower-order step against a higher-order one.
            if order == 2:
                x_low, eps_cache = self.dpm_solver_1_step(x, s, t_, eps_cache=eps_cache)
                x_high, eps_cache = self.dpm_solver_2_step(x, s, t_, eps_cache=eps_cache)
            else:
                x_low, eps_cache = self.dpm_solver_2_step(x, s, t_, r1=1 / 3, eps_cache=eps_cache)
                x_high, eps_cache = self.dpm_solver_3_step(x, s, t_, eps_cache=eps_cache)
            delta = torch.maximum(atol, rtol * torch.maximum(x_low.abs(), x_prev.abs()))
            error = torch.linalg.norm((x_low - x_high) / delta) / x.numel() ** 0.5
            accept = pid.propose_step(error)
            if accept:
                x_prev = x_low
                x = x_high + su * s_noise * noise_sampler(self.sigma(s), self.sigma(t))
                s = t
                info['n_accept'] += 1
            else:
                info['n_reject'] += 1
            info['nfe'] += order
            info['steps'] += 1

            if self.info_callback is not None:
                self.info_callback({'x': x, 'i': info['steps'] - 1, 't': s, 't_up': s, 'denoised': denoised, 'error': error, 'h': pid.h, **info})

        return x, info
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
@torch.no_grad()
def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1., noise_sampler=None):
    """DPM-Solver-Fast (fixed step size). See https://arxiv.org/abs/2206.00927.

    Thin wrapper around DPMSolver.dpm_solver_fast that converts the sigma
    bounds into the solver's t-space and wires up progress/callback reporting.
    """
    if sigma_min <= 0 or sigma_max <= 0:
        raise ValueError('sigma_min and sigma_max must not be 0')
    with tqdm(total=n, disable=disable) as pbar:
        solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
        if callback is not None:
            def relay(info):
                # Report sigma-space values alongside the solver's t-space info.
                return callback({'sigma': solver.sigma(info['t']), 'sigma_hat': solver.sigma(info['t_up']), **info})
            solver.info_callback = relay
        t_start = solver.t(torch.tensor(sigma_max))
        t_end = solver.t(torch.tensor(sigma_min))
        return solver.dpm_solver_fast(x, t_start, t_end, n, eta, s_noise, noise_sampler)
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
@torch.no_grad()
def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None, return_info=False):
    """DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927.

    Thin wrapper around DPMSolver.dpm_solver_adaptive; converts the sigma
    bounds into the solver's t-space and optionally returns solver statistics.
    """
    if sigma_min <= 0 or sigma_max <= 0:
        raise ValueError('sigma_min and sigma_max must not be 0')
    with tqdm(disable=disable) as pbar:
        solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
        if callback is not None:
            def relay(info):
                # Report sigma-space values alongside the solver's t-space info.
                return callback({'sigma': solver.sigma(info['t']), 'sigma_hat': solver.sigma(info['t_up']), **info})
            solver.info_callback = relay
        t_start = solver.t(torch.tensor(sigma_max))
        t_end = solver.t(torch.tensor(sigma_min))
        x, info = solver.dpm_solver_adaptive(x, t_start, t_end, order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise, noise_sampler)
    return (x, info) if return_info else x
|
| 506 |
+
|
| 507 |
+
|
| 508 |
+
@torch.no_grad()
def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with DPM-Solver++(2S) second-order steps.

    model is called as model(x, sigma, **extra_args) and is expected to return
    the denoised sample; eta controls the ancestral noise amount and s_noise
    scales the injected noise.
    """
    extra_args = {} if extra_args is None else extra_args
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])  # broadcasts each scalar sigma over the batch
    # DPM-Solver++ works in half-log-SNR space: t = -log(sigma), sigma = exp(-t).
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigma_down == 0:
            # Euler method
            d = to_d(x, sigmas[i], denoised)
            dt = sigma_down - sigmas[i]
            x = x + d * dt
        else:
            # DPM-Solver++(2S)
            t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
            r = 1 / 2  # midpoint of the interval for the intermediate evaluation
            h = t_next - t
            s = t + r * h
            x_2 = (sigma_fn(s) / sigma_fn(t)) * x - (-h * r).expm1() * denoised
            denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_2
        # Noise addition
        if sigmas[i + 1] > 0:
            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
    return x
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
@torch.no_grad()
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
    """DPM-Solver++ (stochastic).

    Args:
        model: Denoiser; called as model(x, sigma, **extra_args).
        x: Initial noised latents.
        sigmas: Noise schedule (1-D tensor; the final entry may be 0).
        eta: Amount of ancestral noise injected per step.
        s_noise: Multiplier on the injected noise.
        noise_sampler: Optional (sigma, sigma_next) -> noise callable; defaults
            to a Brownian tree sampler.
        r: Relative position of the intermediate evaluation within each step.
    """
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])  # broadcasts each scalar sigma over the batch
    # Work in half-log-SNR space: t = -log(sigma), sigma = exp(-t).
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Euler method
            d = to_d(x, sigmas[i], denoised)
            dt = sigmas[i + 1] - sigmas[i]
            x = x + d * dt
        else:
            # DPM-Solver++
            t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
            h = t_next - t
            s = t + h * r
            fac = 1 / (2 * r)  # extrapolation weight for the second model output

            # Step 1
            sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
            s_ = t_fn(sd)
            x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
            x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
            denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)

            # Step 2
            sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
            t_next_ = t_fn(sd)
            denoised_d = (1 - fac) * denoised + fac * denoised_2
            x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
            x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
    return x
|
| 582 |
+
|
| 583 |
+
|
| 584 |
+
@torch.no_grad()
def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """DPM-Solver++(2M).

    Deterministic second-order multistep sampler: each step reuses the previous
    step's model output instead of making a second model evaluation.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])  # broadcasts each scalar sigma over the batch
    # Work in half-log-SNR space: t = -log(sigma), sigma = exp(-t).
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()
    old_denoised = None  # previous step's model output, used by the multistep correction

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
        h = t_next - t
        if old_denoised is None or sigmas[i + 1] == 0:
            # First step (no history yet) or final step to sigma 0: first-order update.
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
        else:
            h_last = t - t_fn(sigmas[i - 1])
            r = h_last / h  # ratio of previous to current step size
            # Linear extrapolation from the previous model output (2nd-order multistep).
            denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
        old_denoised = denoised
    return x
|
| 608 |
+
|
| 609 |
+
|
| 610 |
+
@torch.no_grad()
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
    """DPM-Solver++(2M) SDE.

    Args:
        eta: Strength of the stochastic (SDE) component; 0 disables noise injection.
        s_noise: Multiplier on the injected noise.
        noise_sampler: Optional (sigma, sigma_next) -> noise callable; defaults
            to a Brownian tree sampler.
        solver_type: 'midpoint' or 'heun' second-order correction.
    """

    if solver_type not in {'heun', 'midpoint'}:
        raise ValueError('solver_type must be \'heun\' or \'midpoint\'')

    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])  # broadcasts each scalar sigma over the batch

    old_denoised = None  # previous step's model output
    h_last = None  # previous step size in log-sigma space

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Denoising step
            x = denoised
        else:
            # DPM-Solver++(2M) SDE
            t, s = -sigmas[i].log(), -sigmas[i + 1].log()
            h = s - t
            eta_h = eta * h

            x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised

            if old_denoised is not None:
                r = h_last / h  # ratio of previous to current step size
                if solver_type == 'heun':
                    x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
                elif solver_type == 'midpoint':
                    x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)

            if eta:
                # Inject noise with the variance implied by the SDE contraction.
                x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise

        old_denoised = denoised
        h_last = h
    return x
|
| 653 |
+
|
| 654 |
+
|
| 655 |
+
@torch.no_grad()
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """DPM-Solver++(3M) SDE.

    Third-order multistep stochastic sampler; keeps the two most recent model
    outputs to apply second- and third-order corrections.
    """

    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])  # broadcasts each scalar sigma over the batch

    denoised_1, denoised_2 = None, None  # model outputs from the previous two steps
    h_1, h_2 = None, None  # step sizes from the previous two steps

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Denoising step
            x = denoised
        else:
            t, s = -sigmas[i].log(), -sigmas[i + 1].log()
            h = s - t
            h_eta = h * (eta + 1)

            x = torch.exp(-h_eta) * x + (-h_eta).expm1().neg() * denoised

            if h_2 is not None:
                # Third-order correction from the two most recent model outputs.
                r0 = h_1 / h
                r1 = h_2 / h
                d1_0 = (denoised - denoised_1) / r0
                d1_1 = (denoised_1 - denoised_2) / r1
                d1 = d1_0 + (d1_0 - d1_1) * r0 / (r0 + r1)
                d2 = (d1_0 - d1_1) / (r0 + r1)
                phi_2 = h_eta.neg().expm1() / h_eta + 1
                phi_3 = phi_2 / h_eta - 0.5
                x = x + phi_2 * d1 - phi_3 * d2
            elif h_1 is not None:
                # Second-order correction while only one past output exists.
                r = h_1 / h
                d = (denoised - denoised_1) / r
                phi_2 = h_eta.neg().expm1() / h_eta + 1
                x = x + phi_2 * d

            if eta:
                x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise

            denoised_1, denoised_2 = denoised, denoised_1
            h_1, h_2 = h, h_1
    return x
|
k_diffusion/utils.py
ADDED
|
@@ -0,0 +1,457 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from contextlib import contextmanager
|
| 2 |
+
import hashlib
|
| 3 |
+
import math
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import shutil
|
| 6 |
+
import threading
|
| 7 |
+
import urllib
|
| 8 |
+
import warnings
|
| 9 |
+
|
| 10 |
+
from PIL import Image
|
| 11 |
+
import safetensors
|
| 12 |
+
import torch
|
| 13 |
+
from torch import nn, optim
|
| 14 |
+
from torch.utils import data
|
| 15 |
+
from torchvision.transforms import functional as TF
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def from_pil_image(x):
    """Convert a PIL image to a tensor scaled to [-1, 1]."""
    tensor = TF.to_tensor(x)
    if tensor.ndim == 2:
        # Give a rank-2 result an explicit trailing channel axis.
        tensor = tensor[..., None]
    return tensor * 2 - 1
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def to_pil_image(x):
    """Convert a tensor in [-1, 1] to a PIL image."""
    if x.ndim == 4:
        # Only a batch of one can become a single image.
        assert x.shape[0] == 1
        x = x.squeeze(0)
    if x.shape[0] == 1:
        # Drop a singleton channel axis.
        x = x.squeeze(0)
    rescaled = (x.clamp(-1, 1) + 1) / 2
    return TF.to_pil_image(rescaled)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'):
    """Apply passed in transforms for HuggingFace Datasets."""
    transformed = []
    for image in examples[image_key]:
        transformed.append(transform(image.convert(mode)))
    return {image_key: transformed}
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def append_dims(x, target_dims):
    """Appends singleton dimensions to the end of a tensor until it has target_dims dimensions."""
    n_missing = target_dims - x.ndim
    if n_missing < 0:
        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
    index = (...,) + (None,) * n_missing
    return x[index]
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def n_params(module):
    """Returns the total number of parameters in a module.

    Note: this counts every parameter returned by module.parameters(),
    including parameters with requires_grad=False (the previous docstring's
    claim of "trainable parameters" did not match the implementation).
    """
    return sum(p.numel() for p in module.parameters())
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def download_file(path, url, digest=None):
    """Downloads a file if it does not exist, optionally checking its SHA-256 hash.

    Args:
        path: Destination path; parent directories are created as needed.
        url: URL to download from when the file is missing.
        digest: Optional expected SHA-256 hex digest of the file contents.

    Returns:
        The destination path as a pathlib.Path.

    Raises:
        OSError: If the file's hash does not match ``digest``.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    if not path.exists():
        # The file-level `import urllib` does not guarantee the `request`
        # submodule is loaded; import it explicitly before use.
        import urllib.request
        with urllib.request.urlopen(url) as response, open(path, 'wb') as f:
            shutil.copyfileobj(response, f)
    if digest is not None:
        # Close the file handle deterministically instead of leaking it to GC.
        with open(path, 'rb') as f:
            file_digest = hashlib.sha256(f.read()).hexdigest()
        if digest != file_digest:
            raise OSError(f'hash of {path} (url: {url}) failed to validate')
    return path
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
@contextmanager
def train_mode(model, mode=True):
    """A context manager that places a model into training mode and restores
    the previous mode on exit."""
    saved_modes = [m.training for m in model.modules()]
    try:
        yield model.train(mode)
    finally:
        # Restore each submodule's original training flag, in traversal order.
        for m, was_training in zip(model.modules(), saved_modes):
            m.training = was_training
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def eval_mode(model):
    """A context manager that places a model into evaluation mode and restores
    the previous mode on exit."""
    # Delegates to train_mode with training disabled.
    return train_mode(model, mode=False)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
@torch.no_grad()
def ema_update(model, averaged_model, decay):
    """Incorporates updated model parameters into an exponential moving averaged
    version of a model. It should be called after each optimizer step."""
    model_params = dict(model.named_parameters())
    averaged_params = dict(averaged_model.named_parameters())
    assert model_params.keys() == averaged_params.keys()

    for name, avg_param in averaged_params.items():
        # avg <- decay * avg + (1 - decay) * param
        avg_param.lerp_(model_params[name], 1 - decay)

    model_buffers = dict(model.named_buffers())
    averaged_buffers = dict(averaged_model.named_buffers())
    assert model_buffers.keys() == averaged_buffers.keys()

    for name, avg_buf in averaged_buffers.items():
        # Buffers are copied verbatim rather than averaged.
        avg_buf.copy_(model_buffers[name])
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class EMAWarmup:
    """Implements an EMA warmup using an inverse decay schedule.

    With inv_gamma=1 and power=1 this is a simple running average. Use
    inv_gamma=1, power=2/3 for runs of a million steps or more (decay reaches
    0.999 at 31.6K steps and 0.9999 at 1M steps); use inv_gamma=1, power=3/4
    for shorter runs (0.999 at 10K steps, 0.9999 at 215.4K steps).

    Args:
        inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
        power (float): Exponential factor of EMA warmup. Default: 1.
        min_value (float): The minimum EMA decay rate. Default: 0.
        max_value (float): The maximum EMA decay rate. Default: 1.
        start_at (int): The epoch to start averaging at. Default: 0.
        last_epoch (int): The index of last epoch. Default: 0.
    """

    def __init__(self, inv_gamma=1., power=1., min_value=0., max_value=1., start_at=0,
                 last_epoch=0):
        self.inv_gamma = inv_gamma
        self.power = power
        self.min_value = min_value
        self.max_value = max_value
        self.start_at = start_at
        self.last_epoch = last_epoch

    def state_dict(self):
        """Returns the state of the class as a :class:`dict`."""
        return dict(self.__dict__)

    def load_state_dict(self, state_dict):
        """Loads the class's state.

        Args:
            state_dict (dict): scaler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_value(self):
        """Gets the current EMA decay rate."""
        epoch = max(0, self.last_epoch - self.start_at)
        value = 1 - (1 + epoch / self.inv_gamma) ** -self.power
        if epoch < 0:
            # Unreachable after the clamp above; kept for exact parity with
            # the original implementation.
            return 0.
        return min(self.max_value, max(self.min_value, value))

    def step(self):
        """Updates the step count."""
        self.last_epoch += 1
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class InverseLR(optim.lr_scheduler._LRScheduler):
    """Implements an inverse decay learning rate schedule with an optional exponential
    warmup. When last_epoch=-1, sets initial lr as lr.

    inv_gamma is the number of steps/epochs required for the learning rate to decay to
    (1 / 2)**power of its original value.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1.
        power (float): Exponential factor of learning rate decay. Default: 1.
        warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
            Default: 0.
        min_lr (float): The minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
    """

    def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., min_lr=0.,
                 last_epoch=-1, verbose=False):
        self.inv_gamma = inv_gamma
        self.power = power
        if not 0. <= warmup < 1:
            raise ValueError('Invalid value for warmup')
        self.warmup = warmup
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.")
        return self._get_closed_form_lr()

    def _get_closed_form_lr(self):
        # warmup_factor ramps from (1 - warmup) toward 1 as epochs accumulate.
        warmup_factor = 1 - self.warmup ** (self.last_epoch + 1)
        decay = (1 + self.last_epoch / self.inv_gamma) ** -self.power
        lrs = []
        for base_lr in self.base_lrs:
            lrs.append(warmup_factor * max(self.min_lr, base_lr * decay))
        return lrs
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class ExponentialLR(optim.lr_scheduler._LRScheduler):
    """Implements an exponential learning rate schedule with an optional exponential
    warmup. When last_epoch=-1, sets initial lr as lr. Decays the learning rate
    continuously by decay (default 0.5) every num_steps steps.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        num_steps (float): The number of steps to decay the learning rate by decay in.
        decay (float): The factor by which to decay the learning rate every num_steps
            steps. Default: 0.5.
        warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
            Default: 0.
        min_lr (float): The minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
    """

    def __init__(self, optimizer, num_steps, decay=0.5, warmup=0., min_lr=0.,
                 last_epoch=-1, verbose=False):
        self.num_steps = num_steps
        self.decay = decay
        if not 0. <= warmup < 1:
            raise ValueError('Invalid value for warmup')
        self.warmup = warmup
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.")
        return self._get_closed_form_lr()

    def _get_closed_form_lr(self):
        warmup_factor = 1 - self.warmup ** (self.last_epoch + 1)
        # Per-step multiplier so a full `num_steps` steps multiplies lr by `decay`.
        step_decay = (self.decay ** (1 / self.num_steps)) ** self.last_epoch
        lrs = []
        for base_lr in self.base_lrs:
            lrs.append(warmup_factor * max(self.min_lr, base_lr * step_decay))
        return lrs
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
class ConstantLRWithWarmup(optim.lr_scheduler._LRScheduler):
    """Implements a constant learning rate schedule with an optional exponential
    warmup. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
            Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
    """

    def __init__(self, optimizer, warmup=0., last_epoch=-1, verbose=False):
        if not 0. <= warmup < 1:
            raise ValueError('Invalid value for warmup')
        self.warmup = warmup
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.")
        return self._get_closed_form_lr()

    def _get_closed_form_lr(self):
        # The only schedule here is the warmup ramp toward the base lr.
        warmup_factor = 1 - self.warmup ** (self.last_epoch + 1)
        lrs = []
        for base_lr in self.base_lrs:
            lrs.append(warmup_factor * base_lr)
        return lrs
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def stratified_uniform(shape, group=0, groups=1, dtype=None, device=None):
    """Draws stratified samples from the uniform distribution on [0, 1).

    The last axis of ``shape`` is divided into ``shape[-1] * groups`` equal
    strata; this call returns one jittered sample from each stratum assigned
    to ``group``.
    """
    if groups <= 0:
        raise ValueError(f"groups must be positive, got {groups}")
    if group < 0 or group >= groups:
        raise ValueError(f"group must be in [0, {groups})")
    strata_count = shape[-1] * groups
    stratum_starts = torch.arange(group, strata_count, groups, dtype=dtype, device=device)
    jitter = torch.rand(shape, dtype=dtype, device=device)
    return (stratum_starts + jitter) / strata_count
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
# Thread-local storage for the stratified-sampling settings installed by the
# enable_stratified() context manager.
stratified_settings = threading.local()
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
@contextmanager
def enable_stratified(group=0, groups=1, disable=False):
    """A context manager that enables stratified sampling."""
    settings = {'disable': disable, 'group': group, 'groups': groups}
    try:
        for name, value in settings.items():
            setattr(stratified_settings, name, value)
        yield
    finally:
        # Remove the thread-local settings so sampling falls back to plain uniform.
        for name in settings:
            delattr(stratified_settings, name)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
@contextmanager
def enable_stratified_accelerate(accelerator, disable=False):
    """A context manager that enables stratified sampling, distributing the strata across
    all processes and gradient accumulation steps using settings from Hugging Face Accelerate.

    Args:
        accelerator: A Hugging Face Accelerate ``Accelerator``-like object exposing
            ``process_index``, ``num_processes``, ``gradient_state.num_steps`` and ``step``.
        disable: If True, stratification is turned off and plain uniform sampling is used.
    """
    # The previous version wrapped this body in try/finally with an empty
    # finally block, which was a no-op; enable_stratified() handles cleanup.
    rank = accelerator.process_index
    world_size = accelerator.num_processes
    acc_steps = accelerator.gradient_state.num_steps
    acc_step = accelerator.step % acc_steps
    # Give every (process, accumulation step) pair its own disjoint stratum group.
    group = rank * acc_steps + acc_step
    groups = world_size * acc_steps
    with enable_stratified(group, groups, disable=disable):
        yield
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def stratified_with_settings(shape, dtype=None, device=None):
    """Draws stratified samples from a uniform distribution, using settings from a context
    manager."""
    settings_active = hasattr(stratified_settings, 'disable') and not stratified_settings.disable
    if not settings_active:
        # No enable_stratified() context (or explicitly disabled): plain uniform.
        return torch.rand(shape, dtype=dtype, device=device)
    return stratified_uniform(
        shape, stratified_settings.group, stratified_settings.groups, dtype=dtype, device=device
    )
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32):
    """Draws samples from a lognormal distribution."""
    # Clamp u into (0, 1) so the inverse CDF stays finite.
    u = stratified_with_settings(shape, device=device, dtype=dtype) * (1 - 2e-7) + 1e-7
    normal_samples = torch.distributions.Normal(loc, scale).icdf(u)
    return normal_samples.exp()
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
def rand_log_logistic(shape, loc=0., scale=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
    """Draws samples from an optionally truncated log-logistic distribution."""
    # Work in float64 so the CDF endpoints are computed accurately.
    lo = torch.as_tensor(min_value, device=device, dtype=torch.float64)
    hi = torch.as_tensor(max_value, device=device, dtype=torch.float64)
    cdf_lo = lo.log().sub(loc).div(scale).sigmoid()
    cdf_hi = hi.log().sub(loc).div(scale).sigmoid()
    u = stratified_with_settings(shape, device=device, dtype=torch.float64) * (cdf_hi - cdf_lo) + cdf_lo
    return u.logit().mul(scale).add(loc).exp().to(dtype)
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def rand_log_uniform(shape, min_value, max_value, device='cpu', dtype=torch.float32):
    """Draws samples from a log-uniform distribution."""
    log_lo, log_hi = math.log(min_value), math.log(max_value)
    u = stratified_with_settings(shape, device=device, dtype=dtype)
    return (u * (log_hi - log_lo) + log_lo).exp()
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
def rand_v_diffusion(shape, sigma_data=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
    """Draws samples from a truncated v-diffusion training timestep distribution."""
    # Map the truncation bounds through the arctan CDF of the v-diffusion schedule.
    cdf_lo = math.atan(min_value / sigma_data) * 2 / math.pi
    cdf_hi = math.atan(max_value / sigma_data) * 2 / math.pi
    u = stratified_with_settings(shape, device=device, dtype=dtype) * (cdf_hi - cdf_lo) + cdf_lo
    return torch.tan(u * math.pi / 2) * sigma_data
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
def rand_cosine_interpolated(shape, image_d, noise_d_low, noise_d_high, sigma_data=1., min_value=1e-3, max_value=1e3, device='cpu', dtype=torch.float32):
    """Draws samples from an interpolated cosine timestep distribution (from simple diffusion)."""

    def logsnr_schedule_cosine(t, logsnr_min, logsnr_max):
        # Cosine log-SNR schedule: linear in atan space between the two bounds.
        t_min = math.atan(math.exp(-0.5 * logsnr_max))
        t_max = math.atan(math.exp(-0.5 * logsnr_min))
        return -2 * torch.log(torch.tan(t_min + t * (t_max - t_min)))

    def logsnr_schedule_cosine_shifted(t, image_d, noise_d, logsnr_min, logsnr_max):
        # Shift the schedule by the resolution ratio (noise_d vs. image_d).
        shift = 2 * math.log(noise_d / image_d)
        return logsnr_schedule_cosine(t, logsnr_min - shift, logsnr_max - shift) + shift

    def logsnr_schedule_cosine_interpolated(t, image_d, noise_d_low, noise_d_high, logsnr_min, logsnr_max):
        # Interpolate between the low- and high-noise shifted schedules by t itself.
        logsnr_low = logsnr_schedule_cosine_shifted(t, image_d, noise_d_low, logsnr_min, logsnr_max)
        logsnr_high = logsnr_schedule_cosine_shifted(t, image_d, noise_d_high, logsnr_min, logsnr_max)
        return torch.lerp(logsnr_low, logsnr_high, t)

    # Convert the sigma truncation bounds to log-SNR bounds.
    logsnr_min = -2 * math.log(min_value / sigma_data)
    logsnr_max = -2 * math.log(max_value / sigma_data)
    u = stratified_with_settings(shape, device=device, dtype=dtype)
    logsnr = logsnr_schedule_cosine_interpolated(u, image_d, noise_d_low, noise_d_high, logsnr_min, logsnr_max)
    # sigma = sigma_data * exp(-logsnr / 2)
    return torch.exp(-logsnr / 2) * sigma_data
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
def rand_split_log_normal(shape, loc, scale_1, scale_2, device='cpu', dtype=torch.float32):
    """Draws samples from a split lognormal distribution.

    In log space the two halves are half-normals of widths scale_1 (left) and
    scale_2 (right) around loc; a side is picked with probability proportional
    to its width, then the result is exponentiated.
    """
    # Keep the RNG draw order (randn then rand) so seeded runs are reproducible.
    half_normal = torch.randn(shape, device=device, dtype=dtype).abs()
    side = torch.rand(shape, device=device, dtype=dtype)
    left = loc - half_normal * scale_1
    right = loc + half_normal * scale_2
    p_left = scale_1 / (scale_1 + scale_2)
    return torch.where(side < p_left, left, right).exp()
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
class FolderOfImages(data.Dataset):
    """Recursively finds all images in a directory. It does not support
    classes/targets."""

    IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'}

    def __init__(self, root, transform=None):
        super().__init__()
        self.root = Path(root)
        # Identity transform keeps __getitem__ uniform when none is supplied.
        self.transform = transform if transform is not None else nn.Identity()
        matches = (p for p in self.root.rglob('*') if p.suffix.lower() in self.IMG_EXTENSIONS)
        self.paths = sorted(matches)

    def __repr__(self):
        return f'FolderOfImages(root="{self.root}", len: {len(self)})'

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, key):
        # Decode eagerly inside the context so the file handle is closed here.
        with open(self.paths[key], 'rb') as fp:
            image = Image.open(fp).convert('RGB')
        return (self.transform(image),)
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
class CSVLogger:
    """Appends comma-separated rows to a file, writing the header row only
    when the file is created.

    Fix over the original: the underlying file handle was never closed; this
    version adds ``close()`` and context-manager support (backward compatible —
    ``filename``, ``columns``, ``file`` and ``write`` are unchanged).
    """

    def __init__(self, filename, columns):
        self.filename = Path(filename)
        self.columns = columns
        # Append when the file already exists so prior rows are preserved;
        # otherwise create it and emit the header first.
        if self.filename.exists():
            self.file = open(self.filename, 'a')
        else:
            self.file = open(self.filename, 'w')
            self.write(*self.columns)

    def write(self, *args):
        """Write one row; values are comma-joined and flushed immediately."""
        print(*args, sep=',', file=self.file, flush=True)

    def close(self):
        """Close the underlying file handle."""
        self.file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
@contextmanager
def tf32_mode(cudnn=None, matmul=None):
    """A context manager that sets whether TF32 is allowed on cuDNN or matmul.

    Passing ``None`` for either flag leaves that setting untouched. Previous
    values are restored on exit, even if the body raises.
    """
    saved_cudnn = torch.backends.cudnn.allow_tf32
    saved_matmul = torch.backends.cuda.matmul.allow_tf32
    try:
        if cudnn is not None:
            torch.backends.cudnn.allow_tf32 = cudnn
        if matmul is not None:
            torch.backends.cuda.matmul.allow_tf32 = matmul
        yield
    finally:
        # Only restore what we actually touched.
        if cudnn is not None:
            torch.backends.cudnn.allow_tf32 = saved_cudnn
        if matmul is not None:
            torch.backends.cuda.matmul.allow_tf32 = saved_matmul
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
def get_safetensors_metadata(path):
    """Retrieves the metadata from a safetensors file.

    Uses ``safe_open`` as a context manager so the underlying file handle is
    closed promptly instead of leaking until garbage collection (the original
    never closed it).
    """
    with safetensors.safe_open(path, "pt") as f:
        return f.metadata()
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
def ema_update_dict(values, updates, decay):
    """Merge `updates` into `values` as an exponential moving average, in place.

    Keys absent from `values` are initialized with the update value; existing
    entries become ``decay * old + (1 - decay) * new``. The in-place ``*=`` /
    ``+=`` operators are kept deliberately so tensor values are mutated rather
    than rebound. Returns the (mutated) `values` dict.
    """
    for key, update in updates.items():
        if key in values:
            values[key] *= decay
            values[key] += (1 - decay) * update
        else:
            values[key] = update
    return values
|
ldm/modules/diffusionmodules/__init__.py
ADDED
|
File without changes
|
ldm/modules/diffusionmodules/model.py
ADDED
|
@@ -0,0 +1,488 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pytorch_diffusion + derived encoder decoder
|
| 2 |
+
import math
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
from einops import rearrange
|
| 8 |
+
from typing import Optional, Any
|
| 9 |
+
|
| 10 |
+
from refnet.util import checkpoint_wrapper, default
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
import xformers
|
| 14 |
+
import xformers.ops
|
| 15 |
+
|
| 16 |
+
XFORMERS_IS_AVAILBLE = True
|
| 17 |
+
attn_processor = xformers.ops.memory_efficient_attention
|
| 18 |
+
except:
|
| 19 |
+
XFORMERS_IS_AVAILBLE = False
|
| 20 |
+
attn_processor = F.scaled_dot_product_attention
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def get_timestep_embedding(timesteps, embedding_dim):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models:
    From Fairseq.
    Build sinusoidal embeddings.
    This matches the implementation in tensor2tensor, but differs slightly
    from the description in Section 3.5 of "Attention Is All You Need".
    """
    assert len(timesteps.shape) == 1

    half = embedding_dim // 2
    # Geometric frequency ladder from 1 down to ~1/10000.
    freqs = torch.exp(
        torch.arange(half, dtype=torch.float32) * -(math.log(10000) / (half - 1))
    ).to(device=timesteps.device)
    angles = timesteps.float()[:, None] * freqs[None, :]
    out = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
    if embedding_dim % 2 == 1:
        # Odd target dim: zero-pad one trailing column.
        out = torch.nn.functional.pad(out, (0, 1, 0, 0))
    return out
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def nonlinearity(x):
    """Swish / SiLU activation: x * sigmoid(x)."""
    return torch.sigmoid(x) * x
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def Normalize(in_channels, num_groups=32):
    """GroupNorm with the fixed eps/affine settings used throughout this model."""
    return torch.nn.GroupNorm(
        num_channels=in_channels, num_groups=num_groups, eps=1e-6, affine=True
    )
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class Upsample(nn.Module):
    """2x nearest-neighbour upsampling, optionally followed by a 3x3 conv."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1
            )

    def forward(self, x):
        out = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(out) if self.with_conv else out
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class Downsample(nn.Module):
    """2x downsampling: strided 3x3 conv (with asymmetric padding) or 2x2
    average pooling when with_conv is False."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0
            )

    def forward(self, x):
        if not self.with_conv:
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        # Pad right/bottom by one so the strided conv halves the resolution.
        padded = torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class ResnetBlock(nn.Module):
    # Pre-activation residual block (GroupNorm -> swish -> 3x3 conv, twice),
    # optionally conditioned on a timestep embedding. When in/out channel
    # counts differ, the skip path is projected with a 3x3 conv
    # (use_conv_shortcut) or a 1x1 conv.
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        # temb_channels == 0 disables timestep conditioning (autoencoder use).
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels,
                                             out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels,
                                                     out_channels,
                                                     kernel_size=3,
                                                     stride=1,
                                                     padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels,
                                                    out_channels,
                                                    kernel_size=1,
                                                    stride=1,
                                                    padding=0)

    # checkpoint_wrapper (refnet.util) presumably enables activation
    # checkpointing around this forward — confirm against its definition.
    @checkpoint_wrapper
    def forward(self, x, temb=None):
        """Return shortcut(x) + residual(x); temb, if given, is projected and
        broadcast-added after the first conv."""
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x+h
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class AttnBlock(nn.Module):
    # Single-head spatial self-attention implemented with explicit bmm and
    # softmax; MemoryEfficientAttnBlock below is the fast equivalent.
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        # 1x1 convs act as per-position linear projections for q/k/v/out.
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)

    def forward(self, x):
        """Residual attention: returns x + proj(attention(norm(x)))."""
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b,c,h,w = q.shape
        q = q.reshape(b,c,h*w)
        q = q.permute(0,2,1)   # b,hw,c
        k = k.reshape(b,c,h*w) # b,c,hw
        w_ = torch.bmm(q,k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c)**(-0.5))  # scale scores by 1/sqrt(c)
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b,c,h*w)
        w_ = w_.permute(0,2,1)   # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v,w_)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b,c,h,w)

        h_ = self.proj_out(h_)

        return x+h_
|
| 208 |
+
|
| 209 |
+
class MemoryEfficientAttnBlock(nn.Module):
    """
    Uses xformers efficient implementation,
    see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
    Note: this is a single-head self-attention operation
    """
    #
    def __init__(self, in_channels, head_dim=None):
        super().__init__()
        self.in_channels = in_channels
        # head_dim defaults to in_channels, i.e. a single attention head.
        self.head_dim = default(head_dim, in_channels)
        self.heads = in_channels // self.head_dim
        # if self.head_dim > 256:
        #     self.attn_processor = F.scaled_dot_product_attention
        # else:
        # Module-level choice: xformers memory_efficient_attention when
        # available, otherwise torch's scaled_dot_product_attention.
        self.attn_processor = attn_processor

        self.norm = Normalize(in_channels)
        # 1x1 convs as per-position q/k/v/out projections.
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)
        self.attention_op: Optional[Any] = None

    def forward(self, x):
        """Residual attention: returns x + proj(attention(norm(x)))."""
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        B, C, H, W = q.shape
        q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))

        # NOTE(review): these reshapes use C (full channel count) rather than
        # self.head_dim, which only matches when heads == 1 — confirm intent
        # for multi-head configurations.
        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(B, -1, self.heads, C)
            .permute(0, 2, 1, 3)
            .reshape(B * self.heads, -1, C)
            .contiguous(),
            (q, k, v),
        )
        out = self.attn_processor(q, k, v)

        out = (
            out.unsqueeze(0)
            .reshape(B, 1, out.shape[1], C)
            .permute(0, 2, 1, 3)
            .reshape(B, out.shape[1], C)
        )
        out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
        out = self.proj_out(out)
        return x+out
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def make_attn(in_channels, **kwargs):
    """Build a spatial attention block.

    Keyword arguments (e.g. attn_type) are accepted for call-site
    compatibility but ignored: the memory-efficient single-head block is
    always returned.
    """
    attn_block = MemoryEfficientAttnBlock(in_channels)
    return attn_block
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
class Encoder(nn.Module):
    # VAE-style encoder: conv stem, one stage per entry of ch_mult (ResnetBlocks
    # with optional attention, then downsampling), a middle block, and a conv
    # head emitting 2*z_channels (mean/logvar) when double_z else z_channels.
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
                 checkpoint=True, **ignore_kwargs):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0  # no timestep conditioning in the autoencoder
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        in_ch_mult = (1,)+tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                # Attention only at the spatial resolutions listed in
                # attn_resolutions.
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        2*z_channels if double_z else z_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)
        self.checkpoint = checkpoint

    @checkpoint_wrapper
    def forward(self, x):
        """Encode an image batch to latent moments, spatially downsampled by
        2**(num_resolutions-1)."""
        # timestep embedding
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions-1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
class Decoder(nn.Module):
    # VAE-style decoder: conv from z, middle block, one upsampling stage per
    # entry of ch_mult (ResnetBlocks with optional attention), then a conv head
    # to out_ch image channels, optionally tanh-squashed.
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
                 attn_type="vanilla", checkpoint=True, **ignorekwargs):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0  # no timestep conditioning in the autoencoder
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,)+tuple(ch_mult)
        block_in = ch*ch_mult[self.num_resolutions-1]
        curr_res = resolution // 2**(self.num_resolutions-1)
        self.z_shape = (1,z_channels,curr_res,curr_res)

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels,
                                       block_in,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            # One extra ResnetBlock per stage compared to the encoder.
            for i_block in range(self.num_res_blocks+1):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up) # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)
        self.checkpoint = checkpoint

    @checkpoint_wrapper
    def forward(self, z):
        """Decode latents z back to an image batch."""
        #assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks+1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        # give_pre_end returns features before the output norm/conv head.
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h
|
ldm/modules/distributions/__init__.py
ADDED
|
File without changes
|
ldm/modules/distributions/distributions.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class AbstractDistribution:
    """Interface for latent distributions.

    Subclasses must implement ``sample`` (draw a value) and ``mode``
    (most likely value).
    """

    def sample(self):
        raise NotImplementedError()

    def mode(self):
        raise NotImplementedError()
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class DiracDistribution(AbstractDistribution):
    """Degenerate distribution concentrated on a single value: both sampling
    and the mode return the stored value unchanged."""

    def __init__(self, value):
        self.value = value

    def sample(self):
        return self.value

    def mode(self):
        return self.value
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian over latents, parameterized by a tensor whose channel
    dimension concatenates mean and log-variance.

    When ``deterministic`` is set, the variance is zeroed so ``sample()``
    degenerates to the mean and divergences are zero.
    """

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        # Clamp for numerical stability before exponentiating.
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            zero = torch.zeros_like(self.mean).to(device=self.parameters.device)
            self.std = zero
            self.var = zero

    def sample(self):
        """Reparameterized draw: mean + std * eps."""
        eps = torch.randn(self.mean.shape).to(device=self.parameters.device)
        return self.mean + self.std * eps

    def kl(self, other=None):
        """KL divergence to a standard normal (other=None) or to another
        diagonal Gaussian, summed over non-batch dims."""
        if self.deterministic:
            return torch.Tensor([0.])
        if other is None:
            return 0.5 * torch.sum(
                self.mean.pow(2) + self.var - 1.0 - self.logvar,
                dim=[1, 2, 3])
        return 0.5 * torch.sum(
            (self.mean - other.mean).pow(2) / other.var
            + self.var / other.var - 1.0 - self.logvar + other.logvar,
            dim=[1, 2, 3])

    def nll(self, sample, dims=[1,2,3]):
        """Negative log-likelihood of `sample` under this Gaussian."""
        if self.deterministic:
            return torch.Tensor([0.])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + (sample - self.mean).pow(2) / self.var,
            dim=dims)

    def mode(self):
        return self.mean
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    # At least one input must already be a tensor so scalars can be promoted
    # onto its device/dtype.
    tensor = next(
        (obj for obj in (mean1, logvar1, mean2, logvar2)
         if isinstance(obj, torch.Tensor)),
        None,
    )
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for torch.exp().
    logvar1, logvar2 = [
        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    ]

    diff = mean1 - mean2
    return 0.5 * (
        logvar2
        - logvar1
        - 1.0
        + torch.exp(logvar1 - logvar2)
        + (diff ** 2) * torch.exp(-logvar2)
    )
|
preprocessor/__init__.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import torch.hub
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
import torchvision.transforms.functional as tf
|
| 7 |
+
import functools
|
| 8 |
+
|
| 9 |
+
model_path = "preprocessor/weights"
|
| 10 |
+
os.environ["HF_HOME"] = model_path
|
| 11 |
+
torch.hub.set_dir(model_path)
|
| 12 |
+
|
| 13 |
+
from torch.hub import download_url_to_file
|
| 14 |
+
from transformers import AutoModelForImageSegmentation
|
| 15 |
+
from .anime2sketch import UnetGenerator
|
| 16 |
+
from .manga_line_extractor import res_skip
|
| 17 |
+
from .sketchKeras import SketchKeras
|
| 18 |
+
from .sk_model import LineartDetector
|
| 19 |
+
from .anime_segment import ISNetDIS
|
| 20 |
+
from refnet.util import load_weights
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class NoneMaskExtractor(nn.Module):
    """No-op mask extractor: always yields an all-zero single-channel mask of
    the input's spatial size (used when masking is disabled)."""

    def __init__(self):
        super().__init__()
        self.identity = nn.Identity()

    def proceed(self, x: torch.Tensor, th=None, tw=None, dilate=False, *args, **kwargs):
        # th/tw/dilate are accepted for interface parity with real extractors
        # and ignored here.
        batch, _, height, width = x.shape
        return torch.zeros([batch, 1, height, width], device=x.device)

    def forward(self, x):
        return self.proceed(x)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# Download URLs for single-file line-art extraction / segmentation checkpoints,
# fetched on demand into `model_path` by create_model.
remote_model_dict = {
    "lineart": "https://huggingface.co/lllyasviel/Annotators/resolve/main/netG.pth",
    "lineart_denoise": "https://huggingface.co/lllyasviel/Annotators/resolve/main/erika.pth",
    "lineart_keras": "https://huggingface.co/tellurion/line_extractor/resolve/main/model.pth",
    "lineart_sk": "https://huggingface.co/lllyasviel/Annotators/resolve/main/sk_model.pth",
    "ISNet": "https://huggingface.co/tellurion/line_extractor/resolve/main/isnetis.safetensors",
    "ISNet-sketch": "https://huggingface.co/tellurion/line_extractor/resolve/main/sketch-segment.safetensors"
}

# (HuggingFace repo id, input resolution) for BiRefNet-style background-removal
# models loaded through AutoModelForImageSegmentation.
BiRefNet_dict = {
    "rmbg-v2": ("briaai/RMBG-2.0", 1024),
    "BiRefNet": ("ZhengPeng7/BiRefNet", 1024),
    "BiRefNet_HR": ("ZhengPeng7/BiRefNet_HR", 2048)
}
|
| 50 |
+
|
| 51 |
+
def rmbg_proceed(self, x: torch.Tensor, th=None, tw=None, dilate=False, *args, **kwargs):
    # Bound method monkey-patched onto a BiRefNet/RMBG segmentation model in
    # create_model: produces a foreground mask for x at the input's spatial
    # size. Assumes x is in [-1, 1] — TODO confirm against callers.
    b, c, h, w = x.shape
    x = (x + 1.0) / 2.  # map [-1, 1] -> [0, 1] before ImageNet normalization
    x = tf.resize(x, [self.image_size, self.image_size])
    x = tf.normalize(x, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    x = self(x)[-1].sigmoid()  # last head output -> per-pixel probabilities
    x = tf.resize(x, [h, w])

    # Optionally center-pad the mask out to a (th, tw) canvas.
    if th and tw:
        x = tf.pad(x, padding=[(th-h)//2, (tw-w)//2])
    if dilate:
        # Grow the mask by ~10 px via max pooling, then binarize at 0.5.
        x = F.max_pool2d(x, kernel_size=21, stride=1, padding=10)
        # x = F.max_pool2d(x, kernel_size=11, stride=1, padding=5)
        # x = mask_expansion(x, 60, 40)
        x = torch.where(x > 0.5, torch.ones_like(x), torch.zeros_like(x))
    x = x.clamp(0, 1)
    return x
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def create_model(model_name="lineart"):
|
| 72 |
+
"""Create a model for anime2sketch
|
| 73 |
+
hardcoding the options for simplicity
|
| 74 |
+
"""
|
| 75 |
+
if model_name == "none":
|
| 76 |
+
return NoneMaskExtractor().eval()
|
| 77 |
+
|
| 78 |
+
if model_name in BiRefNet_dict.keys():
|
| 79 |
+
model = AutoModelForImageSegmentation.from_pretrained(
|
| 80 |
+
BiRefNet_dict[model_name][0],
|
| 81 |
+
trust_remote_code = True,
|
| 82 |
+
cache_dir = model_path,
|
| 83 |
+
device_map = None,
|
| 84 |
+
low_cpu_mem_usage = False,
|
| 85 |
+
)
|
| 86 |
+
model.eval()
|
| 87 |
+
model.image_size = BiRefNet_dict[model_name][1]
|
| 88 |
+
model.proceed = rmbg_proceed.__get__(model, model.__class__)
|
| 89 |
+
return model
|
| 90 |
+
|
| 91 |
+
assert model_name in remote_model_dict.keys()
|
| 92 |
+
remote_path = remote_model_dict[model_name]
|
| 93 |
+
basename = os.path.basename(remote_path)
|
| 94 |
+
ckpt_path = os.path.join(model_path, basename)
|
| 95 |
+
|
| 96 |
+
if not os.path.exists(model_path):
|
| 97 |
+
os.makedirs(model_path)
|
| 98 |
+
|
| 99 |
+
if not os.path.exists(ckpt_path):
|
| 100 |
+
cache_path = "preprocessor/weights/weights.tmp"
|
| 101 |
+
download_url_to_file(remote_path, dst=cache_path)
|
| 102 |
+
os.rename(cache_path, ckpt_path)
|
| 103 |
+
|
| 104 |
+
if model_name == "lineart":
|
| 105 |
+
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
|
| 106 |
+
model = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False)
|
| 107 |
+
elif model_name == "lineart_denoise":
|
| 108 |
+
model = res_skip()
|
| 109 |
+
elif model_name == "lineart_keras":
|
| 110 |
+
model = SketchKeras()
|
| 111 |
+
elif model_name == "lineart_sk":
|
| 112 |
+
model = LineartDetector()
|
| 113 |
+
elif model_name == "ISNet" or model_name == "ISNet-sketch":
|
| 114 |
+
model = ISNetDIS()
|
| 115 |
+
else:
|
| 116 |
+
return None
|
| 117 |
+
|
| 118 |
+
ckpt = load_weights(ckpt_path)
|
| 119 |
+
for key in list(ckpt.keys()):
|
| 120 |
+
if 'module.' in key:
|
| 121 |
+
ckpt[key.replace('module.', '')] = ckpt[key]
|
| 122 |
+
del ckpt[key]
|
| 123 |
+
model.load_state_dict(ckpt)
|
| 124 |
+
return model.eval()
|
preprocessor/anime2sketch.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import functools
|
| 4 |
+
import torchvision.transforms as transforms
|
| 5 |
+
|
| 6 |
+
"""
|
| 7 |
+
Anime2Sketch: A sketch extractor for illustration, anime art, manga
|
| 8 |
+
Author: Xiaoyu Zhang
|
| 9 |
+
Github link: https://github.com/Mukosame/Anime2Sketch
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
def to_tensor(x, inverse=False):
|
| 13 |
+
x = transforms.ToTensor()(x).unsqueeze(0)
|
| 14 |
+
x = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(x).cuda()
|
| 15 |
+
return x if not inverse else -x
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class UnetGenerator(nn.Module):
|
| 19 |
+
"""Create a Unet-based generator"""
|
| 20 |
+
|
| 21 |
+
def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
|
| 22 |
+
"""Construct a Unet generator
|
| 23 |
+
Parameters:
|
| 24 |
+
input_nc (int) -- the number of channels in input images
|
| 25 |
+
output_nc (int) -- the number of channels in output images
|
| 26 |
+
num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
|
| 27 |
+
image of size 128x128 will become of size 1x1 # at the bottleneck
|
| 28 |
+
ngf (int) -- the number of filters in the last conv layer
|
| 29 |
+
norm_layer -- normalization layer
|
| 30 |
+
We construct the U-Net from the innermost layer to the outermost layer.
|
| 31 |
+
It is a recursive process.
|
| 32 |
+
"""
|
| 33 |
+
super(UnetGenerator, self).__init__()
|
| 34 |
+
# construct unet structure
|
| 35 |
+
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer
|
| 36 |
+
for _ in range(num_downs - 5): # add intermediate layers with ngf * 8 filters
|
| 37 |
+
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
|
| 38 |
+
# gradually reduce the number of filters from ngf * 8 to ngf
|
| 39 |
+
unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
|
| 40 |
+
unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
|
| 41 |
+
unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
|
| 42 |
+
self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer
|
| 43 |
+
|
| 44 |
+
def forward(self, input):
|
| 45 |
+
"""Standard forward"""
|
| 46 |
+
return self.model(input)
|
| 47 |
+
|
| 48 |
+
def proceed(self, img):
|
| 49 |
+
sketch = self(to_tensor(img))
|
| 50 |
+
return -sketch
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class UnetSkipConnectionBlock(nn.Module):
|
| 54 |
+
"""Defines the Unet submodule with skip connection.
|
| 55 |
+
X -------------------identity----------------------
|
| 56 |
+
|-- downsampling -- |submodule| -- upsampling --|
|
| 57 |
+
"""
|
| 58 |
+
|
| 59 |
+
def __init__(self, outer_nc, inner_nc, input_nc=None,
|
| 60 |
+
submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
|
| 61 |
+
"""Construct a Unet submodule with skip connections.
|
| 62 |
+
Parameters:
|
| 63 |
+
outer_nc (int) -- the number of filters in the outer conv layer
|
| 64 |
+
inner_nc (int) -- the number of filters in the inner conv layer
|
| 65 |
+
input_nc (int) -- the number of channels in input images/features
|
| 66 |
+
submodule (UnetSkipConnectionBlock) -- previously defined submodules
|
| 67 |
+
outermost (bool) -- if this module is the outermost module
|
| 68 |
+
innermost (bool) -- if this module is the innermost module
|
| 69 |
+
norm_layer -- normalization layer
|
| 70 |
+
use_dropout (bool) -- if use dropout layers.
|
| 71 |
+
"""
|
| 72 |
+
super(UnetSkipConnectionBlock, self).__init__()
|
| 73 |
+
self.outermost = outermost
|
| 74 |
+
if type(norm_layer) == functools.partial:
|
| 75 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
| 76 |
+
else:
|
| 77 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
| 78 |
+
if input_nc is None:
|
| 79 |
+
input_nc = outer_nc
|
| 80 |
+
downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
|
| 81 |
+
stride=2, padding=1, bias=use_bias)
|
| 82 |
+
downrelu = nn.LeakyReLU(0.2, True)
|
| 83 |
+
downnorm = norm_layer(inner_nc)
|
| 84 |
+
uprelu = nn.ReLU(True)
|
| 85 |
+
upnorm = norm_layer(outer_nc)
|
| 86 |
+
|
| 87 |
+
if outermost:
|
| 88 |
+
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
|
| 89 |
+
kernel_size=4, stride=2,
|
| 90 |
+
padding=1)
|
| 91 |
+
down = [downconv]
|
| 92 |
+
up = [uprelu, upconv, nn.Tanh()]
|
| 93 |
+
model = down + [submodule] + up
|
| 94 |
+
elif innermost:
|
| 95 |
+
upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
|
| 96 |
+
kernel_size=4, stride=2,
|
| 97 |
+
padding=1, bias=use_bias)
|
| 98 |
+
down = [downrelu, downconv]
|
| 99 |
+
up = [uprelu, upconv, upnorm]
|
| 100 |
+
model = down + up
|
| 101 |
+
else:
|
| 102 |
+
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
|
| 103 |
+
kernel_size=4, stride=2,
|
| 104 |
+
padding=1, bias=use_bias)
|
| 105 |
+
down = [downrelu, downconv, downnorm]
|
| 106 |
+
up = [uprelu, upconv, upnorm]
|
| 107 |
+
|
| 108 |
+
if use_dropout:
|
| 109 |
+
model = down + [submodule] + up + [nn.Dropout(0.5)]
|
| 110 |
+
else:
|
| 111 |
+
model = down + [submodule] + up
|
| 112 |
+
|
| 113 |
+
self.model = nn.Sequential(*model)
|
| 114 |
+
|
| 115 |
+
def forward(self, x):
|
| 116 |
+
if self.outermost:
|
| 117 |
+
return self.model(x).clamp(-1, 1)
|
| 118 |
+
else: # add skip connections
|
| 119 |
+
return torch.cat([x, self.model(x)], 1)
|
preprocessor/anime_segment.py
ADDED
|
@@ -0,0 +1,487 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
from refnet.util import default
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Source code: https://github.com/SkyTNT/anime-segmentation?tab=readme-ov-file
|
| 9 |
+
Author: SkyTNT
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
class REBNCONV(nn.Module):
|
| 13 |
+
def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
|
| 14 |
+
super(REBNCONV, self).__init__()
|
| 15 |
+
|
| 16 |
+
self.conv_s1 = nn.Conv2d(
|
| 17 |
+
in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride
|
| 18 |
+
)
|
| 19 |
+
self.bn_s1 = nn.BatchNorm2d(out_ch)
|
| 20 |
+
self.relu_s1 = nn.ReLU(inplace=True)
|
| 21 |
+
|
| 22 |
+
def forward(self, x):
|
| 23 |
+
hx = x
|
| 24 |
+
xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
|
| 25 |
+
|
| 26 |
+
return xout
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
## upsample tensor 'src' to have the same spatial size with tensor 'tar'
|
| 30 |
+
def _upsample_like(src, tar):
|
| 31 |
+
src = F.interpolate(src, size=tar.shape[2:], mode="bilinear", align_corners=False)
|
| 32 |
+
|
| 33 |
+
return src
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
### RSU-7 ###
|
| 37 |
+
class RSU7(nn.Module):
|
| 38 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
|
| 39 |
+
super(RSU7, self).__init__()
|
| 40 |
+
|
| 41 |
+
self.in_ch = in_ch
|
| 42 |
+
self.mid_ch = mid_ch
|
| 43 |
+
self.out_ch = out_ch
|
| 44 |
+
|
| 45 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) ## 1 -> 1/2
|
| 46 |
+
|
| 47 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
| 48 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 49 |
+
|
| 50 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 51 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 52 |
+
|
| 53 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 54 |
+
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 55 |
+
|
| 56 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 57 |
+
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 58 |
+
|
| 59 |
+
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 60 |
+
self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 61 |
+
|
| 62 |
+
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 63 |
+
|
| 64 |
+
self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
| 65 |
+
|
| 66 |
+
self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 67 |
+
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 68 |
+
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 69 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 70 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 71 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
| 72 |
+
|
| 73 |
+
def forward(self, x):
|
| 74 |
+
b, c, h, w = x.shape
|
| 75 |
+
|
| 76 |
+
hx = x
|
| 77 |
+
hxin = self.rebnconvin(hx)
|
| 78 |
+
|
| 79 |
+
hx1 = self.rebnconv1(hxin)
|
| 80 |
+
hx = self.pool1(hx1)
|
| 81 |
+
|
| 82 |
+
hx2 = self.rebnconv2(hx)
|
| 83 |
+
hx = self.pool2(hx2)
|
| 84 |
+
|
| 85 |
+
hx3 = self.rebnconv3(hx)
|
| 86 |
+
hx = self.pool3(hx3)
|
| 87 |
+
|
| 88 |
+
hx4 = self.rebnconv4(hx)
|
| 89 |
+
hx = self.pool4(hx4)
|
| 90 |
+
|
| 91 |
+
hx5 = self.rebnconv5(hx)
|
| 92 |
+
hx = self.pool5(hx5)
|
| 93 |
+
|
| 94 |
+
hx6 = self.rebnconv6(hx)
|
| 95 |
+
|
| 96 |
+
hx7 = self.rebnconv7(hx6)
|
| 97 |
+
|
| 98 |
+
hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
|
| 99 |
+
hx6dup = _upsample_like(hx6d, hx5)
|
| 100 |
+
|
| 101 |
+
hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
|
| 102 |
+
hx5dup = _upsample_like(hx5d, hx4)
|
| 103 |
+
|
| 104 |
+
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
|
| 105 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
| 106 |
+
|
| 107 |
+
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
|
| 108 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
| 109 |
+
|
| 110 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
| 111 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
| 112 |
+
|
| 113 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
| 114 |
+
|
| 115 |
+
return hx1d + hxin
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
### RSU-6 ###
|
| 119 |
+
class RSU6(nn.Module):
|
| 120 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
| 121 |
+
super(RSU6, self).__init__()
|
| 122 |
+
|
| 123 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
| 124 |
+
|
| 125 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
| 126 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 127 |
+
|
| 128 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 129 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 130 |
+
|
| 131 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 132 |
+
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 133 |
+
|
| 134 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 135 |
+
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 136 |
+
|
| 137 |
+
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 138 |
+
|
| 139 |
+
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
| 140 |
+
|
| 141 |
+
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 142 |
+
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 143 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 144 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 145 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
| 146 |
+
|
| 147 |
+
def forward(self, x):
|
| 148 |
+
hx = x
|
| 149 |
+
|
| 150 |
+
hxin = self.rebnconvin(hx)
|
| 151 |
+
|
| 152 |
+
hx1 = self.rebnconv1(hxin)
|
| 153 |
+
hx = self.pool1(hx1)
|
| 154 |
+
|
| 155 |
+
hx2 = self.rebnconv2(hx)
|
| 156 |
+
hx = self.pool2(hx2)
|
| 157 |
+
|
| 158 |
+
hx3 = self.rebnconv3(hx)
|
| 159 |
+
hx = self.pool3(hx3)
|
| 160 |
+
|
| 161 |
+
hx4 = self.rebnconv4(hx)
|
| 162 |
+
hx = self.pool4(hx4)
|
| 163 |
+
|
| 164 |
+
hx5 = self.rebnconv5(hx)
|
| 165 |
+
|
| 166 |
+
hx6 = self.rebnconv6(hx5)
|
| 167 |
+
|
| 168 |
+
hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
|
| 169 |
+
hx5dup = _upsample_like(hx5d, hx4)
|
| 170 |
+
|
| 171 |
+
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
|
| 172 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
| 173 |
+
|
| 174 |
+
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
|
| 175 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
| 176 |
+
|
| 177 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
| 178 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
| 179 |
+
|
| 180 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
| 181 |
+
|
| 182 |
+
return hx1d + hxin
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
### RSU-5 ###
|
| 186 |
+
class RSU5(nn.Module):
|
| 187 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
| 188 |
+
super(RSU5, self).__init__()
|
| 189 |
+
|
| 190 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
| 191 |
+
|
| 192 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
| 193 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 194 |
+
|
| 195 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 196 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 197 |
+
|
| 198 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 199 |
+
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 200 |
+
|
| 201 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 202 |
+
|
| 203 |
+
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
| 204 |
+
|
| 205 |
+
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 206 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 207 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 208 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
| 209 |
+
|
| 210 |
+
def forward(self, x):
|
| 211 |
+
hx = x
|
| 212 |
+
|
| 213 |
+
hxin = self.rebnconvin(hx)
|
| 214 |
+
|
| 215 |
+
hx1 = self.rebnconv1(hxin)
|
| 216 |
+
hx = self.pool1(hx1)
|
| 217 |
+
|
| 218 |
+
hx2 = self.rebnconv2(hx)
|
| 219 |
+
hx = self.pool2(hx2)
|
| 220 |
+
|
| 221 |
+
hx3 = self.rebnconv3(hx)
|
| 222 |
+
hx = self.pool3(hx3)
|
| 223 |
+
|
| 224 |
+
hx4 = self.rebnconv4(hx)
|
| 225 |
+
|
| 226 |
+
hx5 = self.rebnconv5(hx4)
|
| 227 |
+
|
| 228 |
+
hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
|
| 229 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
| 230 |
+
|
| 231 |
+
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
|
| 232 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
| 233 |
+
|
| 234 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
| 235 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
| 236 |
+
|
| 237 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
| 238 |
+
|
| 239 |
+
return hx1d + hxin
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
### RSU-4 ###
|
| 243 |
+
class RSU4(nn.Module):
|
| 244 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
| 245 |
+
super(RSU4, self).__init__()
|
| 246 |
+
|
| 247 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
| 248 |
+
|
| 249 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
| 250 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 251 |
+
|
| 252 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 253 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 254 |
+
|
| 255 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 256 |
+
|
| 257 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
| 258 |
+
|
| 259 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 260 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 261 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
| 262 |
+
|
| 263 |
+
def forward(self, x):
|
| 264 |
+
hx = x
|
| 265 |
+
|
| 266 |
+
hxin = self.rebnconvin(hx)
|
| 267 |
+
|
| 268 |
+
hx1 = self.rebnconv1(hxin)
|
| 269 |
+
hx = self.pool1(hx1)
|
| 270 |
+
|
| 271 |
+
hx2 = self.rebnconv2(hx)
|
| 272 |
+
hx = self.pool2(hx2)
|
| 273 |
+
|
| 274 |
+
hx3 = self.rebnconv3(hx)
|
| 275 |
+
|
| 276 |
+
hx4 = self.rebnconv4(hx3)
|
| 277 |
+
|
| 278 |
+
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
|
| 279 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
| 280 |
+
|
| 281 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
| 282 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
| 283 |
+
|
| 284 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
| 285 |
+
|
| 286 |
+
return hx1d + hxin
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
### RSU-4F ###
|
| 290 |
+
class RSU4F(nn.Module):
|
| 291 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
| 292 |
+
super(RSU4F, self).__init__()
|
| 293 |
+
|
| 294 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
| 295 |
+
|
| 296 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
| 297 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
| 298 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
|
| 299 |
+
|
| 300 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
|
| 301 |
+
|
| 302 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
|
| 303 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
|
| 304 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
| 305 |
+
|
| 306 |
+
def forward(self, x):
|
| 307 |
+
hx = x
|
| 308 |
+
|
| 309 |
+
hxin = self.rebnconvin(hx)
|
| 310 |
+
|
| 311 |
+
hx1 = self.rebnconv1(hxin)
|
| 312 |
+
hx2 = self.rebnconv2(hx1)
|
| 313 |
+
hx3 = self.rebnconv3(hx2)
|
| 314 |
+
|
| 315 |
+
hx4 = self.rebnconv4(hx3)
|
| 316 |
+
|
| 317 |
+
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
|
| 318 |
+
hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
|
| 319 |
+
hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
|
| 320 |
+
|
| 321 |
+
return hx1d + hxin
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
class myrebnconv(nn.Module):
|
| 325 |
+
def __init__(
|
| 326 |
+
self,
|
| 327 |
+
in_ch=3,
|
| 328 |
+
out_ch=1,
|
| 329 |
+
kernel_size=3,
|
| 330 |
+
stride=1,
|
| 331 |
+
padding=1,
|
| 332 |
+
dilation=1,
|
| 333 |
+
groups=1,
|
| 334 |
+
):
|
| 335 |
+
super(myrebnconv, self).__init__()
|
| 336 |
+
|
| 337 |
+
self.conv = nn.Conv2d(
|
| 338 |
+
in_ch,
|
| 339 |
+
out_ch,
|
| 340 |
+
kernel_size=kernel_size,
|
| 341 |
+
stride=stride,
|
| 342 |
+
padding=padding,
|
| 343 |
+
dilation=dilation,
|
| 344 |
+
groups=groups,
|
| 345 |
+
)
|
| 346 |
+
self.bn = nn.BatchNorm2d(out_ch)
|
| 347 |
+
self.rl = nn.ReLU(inplace=True)
|
| 348 |
+
|
| 349 |
+
def forward(self, x):
|
| 350 |
+
return self.rl(self.bn(self.conv(x)))
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
class ISNetDIS(nn.Module):
|
| 354 |
+
def __init__(self, in_ch=3, out_ch=1):
|
| 355 |
+
super(ISNetDIS, self).__init__()
|
| 356 |
+
|
| 357 |
+
self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
|
| 358 |
+
self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 359 |
+
|
| 360 |
+
self.stage1 = RSU7(64, 32, 64)
|
| 361 |
+
self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 362 |
+
|
| 363 |
+
self.stage2 = RSU6(64, 32, 128)
|
| 364 |
+
self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 365 |
+
|
| 366 |
+
self.stage3 = RSU5(128, 64, 256)
|
| 367 |
+
self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 368 |
+
|
| 369 |
+
self.stage4 = RSU4(256, 128, 512)
|
| 370 |
+
self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 371 |
+
|
| 372 |
+
self.stage5 = RSU4F(512, 256, 512)
|
| 373 |
+
self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 374 |
+
|
| 375 |
+
self.stage6 = RSU4F(512, 256, 512)
|
| 376 |
+
|
| 377 |
+
# decoder
|
| 378 |
+
self.stage5d = RSU4F(1024, 256, 512)
|
| 379 |
+
self.stage4d = RSU4(1024, 128, 256)
|
| 380 |
+
self.stage3d = RSU5(512, 64, 128)
|
| 381 |
+
self.stage2d = RSU6(256, 32, 64)
|
| 382 |
+
self.stage1d = RSU7(128, 16, 64)
|
| 383 |
+
|
| 384 |
+
self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
|
| 385 |
+
self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
|
| 386 |
+
self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
|
| 387 |
+
self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
|
| 388 |
+
self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
|
| 389 |
+
self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
|
| 390 |
+
|
| 391 |
+
def forward(self, x):
|
| 392 |
+
hx = x
|
| 393 |
+
|
| 394 |
+
hxin = self.conv_in(hx)
|
| 395 |
+
hx = self.pool_in(hxin)
|
| 396 |
+
|
| 397 |
+
# stage 1
|
| 398 |
+
hx1 = self.stage1(hxin)
|
| 399 |
+
hx = self.pool12(hx1)
|
| 400 |
+
|
| 401 |
+
# stage 2
|
| 402 |
+
hx2 = self.stage2(hx)
|
| 403 |
+
hx = self.pool23(hx2)
|
| 404 |
+
|
| 405 |
+
# stage 3
|
| 406 |
+
hx3 = self.stage3(hx)
|
| 407 |
+
hx = self.pool34(hx3)
|
| 408 |
+
|
| 409 |
+
# stage 4
|
| 410 |
+
hx4 = self.stage4(hx)
|
| 411 |
+
hx = self.pool45(hx4)
|
| 412 |
+
|
| 413 |
+
# stage 5
|
| 414 |
+
hx5 = self.stage5(hx)
|
| 415 |
+
hx = self.pool56(hx5)
|
| 416 |
+
|
| 417 |
+
# stage 6
|
| 418 |
+
hx6 = self.stage6(hx)
|
| 419 |
+
hx6up = _upsample_like(hx6, hx5)
|
| 420 |
+
|
| 421 |
+
# -------------------- decoder --------------------
|
| 422 |
+
hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
|
| 423 |
+
hx5dup = _upsample_like(hx5d, hx4)
|
| 424 |
+
|
| 425 |
+
hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
|
| 426 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
| 427 |
+
|
| 428 |
+
hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
|
| 429 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
| 430 |
+
|
| 431 |
+
hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
|
| 432 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
| 433 |
+
|
| 434 |
+
hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
|
| 435 |
+
|
| 436 |
+
# side output
|
| 437 |
+
d1 = self.side1(hx1d)
|
| 438 |
+
d1 = _upsample_like(d1, x)
|
| 439 |
+
|
| 440 |
+
# d2 = self.side2(hx2d)
|
| 441 |
+
# d2 = _upsample_like(d2, x)
|
| 442 |
+
#
|
| 443 |
+
# d3 = self.side3(hx3d)
|
| 444 |
+
# d3 = _upsample_like(d3, x)
|
| 445 |
+
#
|
| 446 |
+
# d4 = self.side4(hx4d)
|
| 447 |
+
# d4 = _upsample_like(d4, x)
|
| 448 |
+
#
|
| 449 |
+
# d5 = self.side5(hx5d)
|
| 450 |
+
# d5 = _upsample_like(d5, x)
|
| 451 |
+
#
|
| 452 |
+
# d6 = self.side6(hx6)
|
| 453 |
+
# d6 = _upsample_like(d6, x)
|
| 454 |
+
|
| 455 |
+
# d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
|
| 456 |
+
#
|
| 457 |
+
# return [torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
|
| 458 |
+
# return [d1, d2, d3, d4, d5, d6], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
|
| 459 |
+
return torch.sigmoid(d1)
|
| 460 |
+
|
| 461 |
+
def proceed(self, x: torch.Tensor, th=None, tw=None, s=1024, dilate=False, crop=True, *args, **kwargs):
|
| 462 |
+
b, c, h, w = x.shape
|
| 463 |
+
|
| 464 |
+
if crop:
|
| 465 |
+
th, tw = default(th, h), default(tw, w)
|
| 466 |
+
scale = s / max(h, w)
|
| 467 |
+
h, w = int(h * scale), int(w * scale)
|
| 468 |
+
|
| 469 |
+
canvas = -torch.ones((b, c, s, s), dtype=x.dtype, device=x.device)
|
| 470 |
+
ph, pw = (s - h) // 2, (s - w) // 2
|
| 471 |
+
x = F.interpolate(x, scale_factor=scale, mode="bicubic")
|
| 472 |
+
|
| 473 |
+
canvas[:, :, ph: ph+h, pw: pw+w] = x
|
| 474 |
+
|
| 475 |
+
canvas = 1 - (canvas + 1.) / 2.
|
| 476 |
+
mask = self(canvas)[:, :, ph: ph+h, pw: pw+w]
|
| 477 |
+
|
| 478 |
+
else:
|
| 479 |
+
x = F.interpolate(x, size=(s, s), mode="bicubic")
|
| 480 |
+
mask = self(x)
|
| 481 |
+
|
| 482 |
+
mask = F.interpolate(mask, (th, tw), mode="bicubic").clamp(0, 1)
|
| 483 |
+
|
| 484 |
+
if dilate:
|
| 485 |
+
mask = F.max_pool2d(mask, kernel_size=21, stride=1, padding=10)
|
| 486 |
+
# mask = mask_expansion(mask, 32, 20)
|
| 487 |
+
return mask
|
preprocessor/manga_line_extractor.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torchvision.transforms as transforms
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class _bn_relu_conv(nn.Module):
|
| 6 |
+
def __init__(self, in_filters, nb_filters, fw, fh, subsample=1):
|
| 7 |
+
super(_bn_relu_conv, self).__init__()
|
| 8 |
+
self.model = nn.Sequential(
|
| 9 |
+
nn.BatchNorm2d(in_filters, eps=1e-3),
|
| 10 |
+
nn.LeakyReLU(0.2),
|
| 11 |
+
nn.Conv2d(in_filters, nb_filters, (fw, fh), stride=subsample, padding=(fw//2, fh//2), padding_mode='zeros')
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
def forward(self, x):
|
| 15 |
+
return self.model(x)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class _u_bn_relu_conv(nn.Module):
|
| 19 |
+
def __init__(self, in_filters, nb_filters, fw, fh, subsample=1):
|
| 20 |
+
super(_u_bn_relu_conv, self).__init__()
|
| 21 |
+
self.model = nn.Sequential(
|
| 22 |
+
nn.BatchNorm2d(in_filters, eps=1e-3),
|
| 23 |
+
nn.LeakyReLU(0.2),
|
| 24 |
+
nn.Conv2d(in_filters, nb_filters, (fw, fh), stride=subsample, padding=(fw//2, fh//2)),
|
| 25 |
+
nn.Upsample(scale_factor=2, mode='nearest')
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
def forward(self, x):
|
| 29 |
+
return self.model(x)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class _shortcut(nn.Module):
|
| 34 |
+
def __init__(self, in_filters, nb_filters, subsample=1):
|
| 35 |
+
super(_shortcut, self).__init__()
|
| 36 |
+
self.process = False
|
| 37 |
+
self.model = None
|
| 38 |
+
if in_filters != nb_filters or subsample != 1:
|
| 39 |
+
self.process = True
|
| 40 |
+
self.model = nn.Sequential(
|
| 41 |
+
nn.Conv2d(in_filters, nb_filters, (1, 1), stride=subsample)
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
def forward(self, x, y):
|
| 45 |
+
#print(x.size(), y.size(), self.process)
|
| 46 |
+
if self.process:
|
| 47 |
+
y0 = self.model(x)
|
| 48 |
+
#print("merge+", torch.max(y0+y), torch.min(y0+y),torch.mean(y0+y), torch.std(y0+y), y0.shape)
|
| 49 |
+
return y0 + y
|
| 50 |
+
else:
|
| 51 |
+
#print("merge", torch.max(x+y), torch.min(x+y),torch.mean(x+y), torch.std(x+y), y.shape)
|
| 52 |
+
return x + y
|
| 53 |
+
|
| 54 |
+
class _u_shortcut(nn.Module):
|
| 55 |
+
def __init__(self, in_filters, nb_filters, subsample):
|
| 56 |
+
super(_u_shortcut, self).__init__()
|
| 57 |
+
self.process = False
|
| 58 |
+
self.model = None
|
| 59 |
+
if in_filters != nb_filters:
|
| 60 |
+
self.process = True
|
| 61 |
+
self.model = nn.Sequential(
|
| 62 |
+
nn.Conv2d(in_filters, nb_filters, (1, 1), stride=subsample, padding_mode='zeros'),
|
| 63 |
+
nn.Upsample(scale_factor=2, mode='nearest')
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
def forward(self, x, y):
|
| 67 |
+
if self.process:
|
| 68 |
+
return self.model(x) + y
|
| 69 |
+
else:
|
| 70 |
+
return x + y
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class basic_block(nn.Module):
    """Standard residual block: two 3x3 bn-relu-convs plus a (possibly
    projected) skip connection. ``init_subsample`` > 1 downsamples both the
    main path and the skip."""

    def __init__(self, in_filters, nb_filters, init_subsample=1):
        super(basic_block, self).__init__()
        self.conv1 = _bn_relu_conv(in_filters, nb_filters, 3, 3, subsample=init_subsample)
        self.residual = _bn_relu_conv(nb_filters, nb_filters, 3, 3)
        self.shortcut = _shortcut(in_filters, nb_filters, subsample=init_subsample)

    def forward(self, x):
        return self.shortcut(x, self.residual(self.conv1(x)))
|
| 84 |
+
|
| 85 |
+
class _u_basic_block(nn.Module):
    """Upsampling residual block: the first conv doubles the spatial size, and
    the skip path upsamples/projects to match."""

    def __init__(self, in_filters, nb_filters, init_subsample=1):
        super(_u_basic_block, self).__init__()
        self.conv1 = _u_bn_relu_conv(in_filters, nb_filters, 3, 3, subsample=init_subsample)
        self.residual = _bn_relu_conv(nb_filters, nb_filters, 3, 3)
        self.shortcut = _u_shortcut(in_filters, nb_filters, subsample=init_subsample)

    def forward(self, x):
        return self.shortcut(x, self.residual(self.conv1(x)))
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class _residual_block(nn.Module):
    """A stack of ``repetitions`` basic blocks. The final block downsamples
    (stride 2) unless this is the first layer of the network; only the first
    block adapts the incoming channel count."""

    def __init__(self, in_filters, nb_filters, repetitions, is_first_layer=False):
        super(_residual_block, self).__init__()
        blocks = []
        for idx in range(repetitions):
            stride = 2 if (idx == repetitions - 1 and not is_first_layer) else 1
            source_channels = in_filters if idx == 0 else nb_filters
            blocks.append(
                basic_block(in_filters=source_channels, nb_filters=nb_filters,
                            init_subsample=stride)
            )
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        return self.model(x)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class _upsampling_residual_block(nn.Module):
    """A stack of ``repetitions`` blocks in which only the first upsamples
    (x2) and changes the channel count; the rest are plain basic blocks."""

    def __init__(self, in_filters, nb_filters, repetitions):
        super(_upsampling_residual_block, self).__init__()
        blocks = []
        for idx in range(repetitions):
            if idx == 0:
                blocks.append(_u_basic_block(in_filters=in_filters, nb_filters=nb_filters))
            else:
                blocks.append(basic_block(in_filters=nb_filters, nb_filters=nb_filters))
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        return self.model(x)
|
| 133 |
+
|
| 134 |
+
class res_skip(nn.Module):
    # MangaLineExtraction-style network: residual encoder (block0-4),
    # upsampling decoder (block5-8) with additive skip connections (res1-4),
    # and a 1x1 bn-relu-conv head down to a single channel.
    # Attribute names must match the pretrained checkpoint — do not rename.

    def __init__(self):
        super(res_skip, self).__init__()
        # Encoder: channels grow 24 -> 384; every stage except block0 ends in
        # a stride-2 block that halves the resolution.
        self.block0 = _residual_block(in_filters=1, nb_filters=24, repetitions=2, is_first_layer=True)  # (input)
        self.block1 = _residual_block(in_filters=24, nb_filters=48, repetitions=3)  # (block0)
        self.block2 = _residual_block(in_filters=48, nb_filters=96, repetitions=5)  # (block1)
        self.block3 = _residual_block(in_filters=96, nb_filters=192, repetitions=7)  # (block2)
        self.block4 = _residual_block(in_filters=192, nb_filters=384, repetitions=12)  # (block3)

        # Decoder: each upsampling block doubles the resolution and halves the
        # channels; the same-resolution encoder output is added back via a
        # parameter-free _shortcut (in == out channels, so no projection).
        self.block5 = _upsampling_residual_block(in_filters=384, nb_filters=192, repetitions=7)  # (block4)
        self.res1 = _shortcut(in_filters=192, nb_filters=192)  # (block3, block5, subsample=(1,1))

        self.block6 = _upsampling_residual_block(in_filters=192, nb_filters=96, repetitions=5)  # (res1)
        self.res2 = _shortcut(in_filters=96, nb_filters=96)  # (block2, block6, subsample=(1,1))

        self.block7 = _upsampling_residual_block(in_filters=96, nb_filters=48, repetitions=3)  # (res2)
        self.res3 = _shortcut(in_filters=48, nb_filters=48)  # (block1, block7, subsample=(1,1))

        self.block8 = _upsampling_residual_block(in_filters=48, nb_filters=24, repetitions=2)  # (res3)
        self.res4 = _shortcut(in_filters=24, nb_filters=24)  # (block0, block8, subsample=(1,1))

        # Head: full-resolution refinement, then a 1x1 projection to 1 channel.
        self.block9 = _residual_block(in_filters=24, nb_filters=16, repetitions=2, is_first_layer=True)  # (res4)
        self.conv15 = _bn_relu_conv(in_filters=16, nb_filters=1, fh=1, fw=1, subsample=1)  # (block9; original "(block7)" note looked stale)

    def forward(self, x):
        # Encoder pass.
        x0 = self.block0(x)
        x1 = self.block1(x0)
        x2 = self.block2(x1)
        x3 = self.block3(x2)
        x4 = self.block4(x3)

        # Decoder pass with additive encoder skips.
        x5 = self.block5(x4)
        res1 = self.res1(x3, x5)

        x6 = self.block6(res1)
        res2 = self.res2(x2, x6)

        x7 = self.block7(res2)
        res3 = self.res3(x1, x7)

        x8 = self.block8(res3)
        res4 = self.res4(x0, x8)

        x9 = self.block9(res4)
        y = self.conv15(x9)

        return y

    def proceed(self, sketch):
        # PIL image -> float tensor in [0, 255]; keep only the first channel,
        # yielding shape (1, H, W).
        sketch = transforms.ToTensor()(sketch).unsqueeze(0)[:, 0] * 255
        # Restore the channel dim -> (1, 1, H, W). NOTE(review): hard-codes
        # CUDA (as the other preprocessors here do) — fails on CPU-only hosts.
        sketch = sketch.unsqueeze(1).cuda()
        # Rescale the raw prediction into roughly [-1, 1].
        sketch = self(sketch) / 127.5 - 1
        # Clamp and negate — presumably to flip line polarity for downstream
        # consumers; verify against callers.
        return -sketch.clamp(-1, 1)
|
preprocessor/sk_model.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torchvision.transforms.functional as tf
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# Normalization used by every conv block in this file.
norm_layer = nn.InstanceNorm2d
|
| 6 |
+
|
| 7 |
+
class ResidualBlock(nn.Module):
    """Reflection-padded residual block: two (pad -> 3x3 conv -> norm) stages
    with a ReLU between them, added to the identity input."""

    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()

        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features),
        )

    def forward(self, x):
        return self.conv_block(x) + x
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class Generator(nn.Module):
    """ResNet-style image-to-image generator (CycleGAN layout): c7s1-64, two
    stride-2 downsamples, ``n_residual_blocks`` residual blocks, two
    transposed-conv upsamples, and a final 7x7 projection.

    Attribute names ``model0``-``model4`` must match pretrained checkpoints —
    do not rename.
    """

    def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
        super(Generator, self).__init__()

        # Initial convolution block: 7x7 conv at full resolution -> 64 channels.
        model0 = [ nn.ReflectionPad2d(3),
                   nn.Conv2d(input_nc, 64, 7),
                   norm_layer(64),
                   nn.ReLU(inplace=True) ]
        self.model0 = nn.Sequential(*model0)

        # Downsampling: two stride-2 convs, 64 -> 128 -> 256 channels.
        model1 = []
        in_features = 64
        out_features = in_features*2
        for _ in range(2):
            model1 += [ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                        norm_layer(out_features),
                        nn.ReLU(inplace=True) ]
            in_features = out_features
            out_features = in_features*2
        self.model1 = nn.Sequential(*model1)

        model2 = []
        # Residual blocks at the bottleneck resolution (256 channels).
        for _ in range(n_residual_blocks):
            model2 += [ResidualBlock(in_features)]
        self.model2 = nn.Sequential(*model2)

        # Upsampling: two transposed convs mirroring the downsampling path.
        model3 = []
        out_features = in_features//2
        for _ in range(2):
            model3 += [ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                        norm_layer(out_features),
                        nn.ReLU(inplace=True) ]
            in_features = out_features
            out_features = in_features//2
        self.model3 = nn.Sequential(*model3)

        # Output layer: 7x7 projection to ``output_nc``; optional sigmoid
        # squashes the result into [0, 1].
        model4 = [ nn.ReflectionPad2d(3),
                   nn.Conv2d(64, output_nc, 7)]
        if sigmoid:
            model4 += [nn.Sigmoid()]

        self.model4 = nn.Sequential(*model4)

    def forward(self, x, cond=None):
        # ``cond`` is accepted but unused — presumably kept for interface
        # parity with conditional detectors; confirm before removing.
        out = self.model0(x)
        out = self.model1(out)
        out = self.model2(out)
        out = self.model3(out)
        out = self.model4(out)

        return out
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class LineartDetector(nn.Module):
    """Line-art extractor wrapping a 3-residual-block ResNet ``Generator``.

    Checkpoints for this detector are saved from the bare ``Generator``, so
    ``load_state_dict`` forwards to ``self.model`` rather than using the
    default ``nn.Module`` key layout.
    """

    def __init__(self):
        super().__init__()
        # RGB in, one line-art channel out, 3 residual blocks.
        self.model = Generator(3, 1, 3)

    def load_state_dict(self, sd, strict=True):
        """Load a ``Generator`` checkpoint directly into ``self.model``.

        ``strict`` mirrors ``nn.Module.load_state_dict``; the original
        override silently dropped it, breaking ``strict=False`` callers.
        Returns the usual ``(missing_keys, unexpected_keys)`` result.
        """
        return self.model.load_state_dict(sd, strict=strict)

    def proceed(self, sketch):
        """Run a PIL image through the detector and negate the prediction.

        NOTE(review): the device is hard-coded to CUDA, matching the other
        preprocessors in this package — fails on CPU-only hosts.
        """
        # PIL -> (1, C, H, W) float tensor in [0, 255].
        sketch = tf.pil_to_tensor(sketch).unsqueeze(0).cuda().float()
        # Negated — presumably to flip line polarity; verify against callers.
        return -self.model(sketch)
|
preprocessor/sketchKeras.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def postprocess(pred, thresh=0.18):
    """Collapse a multi-channel prediction into a single map in [-1, 1].

    Takes the per-pixel maximum over dim 0, zeroes responses below
    ``thresh``, then rescales [0, 1] -> [-1, 1].

    Args:
        pred: tensor whose leading dimension is reduced with ``torch.amax``.
        thresh: activation cutoff in [0, 1]; values below it are zeroed.

    Returns:
        The reduced tensor mapped from [0, 1] to [-1, 1].

    Raises:
        ValueError: if ``thresh`` lies outside [0, 1]. (The original used
        ``assert``, which is stripped under ``python -O``.)
    """
    if not 0.0 <= thresh <= 1.0:
        raise ValueError(f"thresh must be in [0, 1], got {thresh}")

    pred = torch.amax(pred, 0)
    pred[pred < thresh] = 0
    # In-place shift/scale of the reduced copy (the caller's tensor is untouched).
    pred -= 0.5
    pred *= 2
    return pred
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class SketchKeras(nn.Module):
    """PyTorch port of the sketchKeras line-extraction U-Net.

    One full-resolution stage plus four stride-2 stages on the way down
    (1 -> 512 channels), then four bicubic-upsampling stages with encoder
    skip concatenations on the way up, and a final 3x3 conv back to one
    channel. Attribute names must match the pretrained checkpoint — do not
    rename.
    """

    def __init__(self):
        super(SketchKeras, self).__init__()

        # Encoder blocks: reflection-pad -> conv -> BN -> ReLU. ``momentum=0``
        # keeps the BatchNorm running statistics frozen at their loaded values.
        self.downblock_1 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(1, 32, kernel_size=3, stride=1),
            nn.BatchNorm2d(32, eps=1e-3, momentum=0),
            nn.ReLU(),
        )
        self.downblock_2 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.BatchNorm2d(64, eps=1e-3, momentum=0),
            nn.ReLU(),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.BatchNorm2d(64, eps=1e-3, momentum=0),
            nn.ReLU(),
        )
        self.downblock_3 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(64, 128, kernel_size=4, stride=2),
            nn.BatchNorm2d(128, eps=1e-3, momentum=0),
            nn.ReLU(),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(128, 128, kernel_size=3, stride=1),
            nn.BatchNorm2d(128, eps=1e-3, momentum=0),
            nn.ReLU(),
        )
        self.downblock_4 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(128, 256, kernel_size=4, stride=2),
            nn.BatchNorm2d(256, eps=1e-3, momentum=0),
            nn.ReLU(),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(256, 256, kernel_size=3, stride=1),
            nn.BatchNorm2d(256, eps=1e-3, momentum=0),
            nn.ReLU(),
        )
        self.downblock_5 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(256, 512, kernel_size=4, stride=2),
            nn.BatchNorm2d(512, eps=1e-3, momentum=0),
            nn.ReLU(),
        )
        self.downblock_6 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(512, 512, kernel_size=3, stride=1),
            nn.BatchNorm2d(512, eps=1e-3, momentum=0),
            nn.ReLU(),
        )

        # Decoder blocks: x2 bicubic upsample, then two convs halving the
        # channels. The asymmetric (1, 2, 1, 2) padding compensates for the
        # even 4x4 kernel so spatial size is preserved. Each block consumes
        # the channel-wise concat of a skip and the previous decoder output.
        self.upblock_1 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bicubic"),
            nn.ReflectionPad2d((1, 2, 1, 2)),
            nn.Conv2d(1024, 512, kernel_size=4, stride=1),
            nn.BatchNorm2d(512, eps=1e-3, momentum=0),
            nn.ReLU(),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(512, 256, kernel_size=3, stride=1),
            nn.BatchNorm2d(256, eps=1e-3, momentum=0),
            nn.ReLU(),
        )

        self.upblock_2 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bicubic"),
            nn.ReflectionPad2d((1, 2, 1, 2)),
            nn.Conv2d(512, 256, kernel_size=4, stride=1),
            nn.BatchNorm2d(256, eps=1e-3, momentum=0),
            nn.ReLU(),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(256, 128, kernel_size=3, stride=1),
            nn.BatchNorm2d(128, eps=1e-3, momentum=0),
            nn.ReLU(),
        )

        self.upblock_3 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bicubic"),
            nn.ReflectionPad2d((1, 2, 1, 2)),
            nn.Conv2d(256, 128, kernel_size=4, stride=1),
            nn.BatchNorm2d(128, eps=1e-3, momentum=0),
            nn.ReLU(),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(128, 64, kernel_size=3, stride=1),
            nn.BatchNorm2d(64, eps=1e-3, momentum=0),
            nn.ReLU(),
        )

        self.upblock_4 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bicubic"),
            nn.ReflectionPad2d((1, 2, 1, 2)),
            nn.Conv2d(128, 64, kernel_size=4, stride=1),
            nn.BatchNorm2d(64, eps=1e-3, momentum=0),
            nn.ReLU(),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(64, 32, kernel_size=3, stride=1),
            nn.BatchNorm2d(32, eps=1e-3, momentum=0),
            nn.ReLU(),
        )

        # Final head: 3x3 conv on the concat of downblock_1 output and the
        # last decoder output (32 + 32 = 64 channels) down to one channel.
        self.last_pad = nn.ReflectionPad2d((1, 1, 1, 1))
        self.last_conv = nn.Conv2d(64, 1, kernel_size=3, stride=1)

    def forward(self, x):
        # Encoder pass.
        d1 = self.downblock_1(x)
        d2 = self.downblock_2(d1)
        d3 = self.downblock_3(d2)
        d4 = self.downblock_4(d3)
        d5 = self.downblock_5(d4)
        d6 = self.downblock_6(d5)

        # Decoder pass: concatenate the same-resolution encoder feature map
        # before each upsampling block (U-Net skips).
        u1 = torch.cat((d5, d6), dim=1)
        u1 = self.upblock_1(u1)
        u2 = torch.cat((d4, u1), dim=1)
        u2 = self.upblock_2(u2)
        u3 = torch.cat((d3, u2), dim=1)
        u3 = self.upblock_3(u3)
        u4 = torch.cat((d2, u3), dim=1)
        u4 = self.upblock_4(u4)
        u5 = torch.cat((d1, u4), dim=1)

        out = self.last_conv(self.last_pad(u5))

        return out

    def proceed(self, img):
        # High-pass the image: subtract a Gaussian-blurred copy so only
        # edges/strokes remain, then scale roughly into [-2, 2].
        img = np.array(img)
        blurred = cv2.GaussianBlur(img, (0, 0), 3)
        img = img.astype(int) - blurred.astype(int)
        img = img.astype(np.float32) / 127.5
        # NOTE(review): np.max(img) can be zero (flat image), which would
        # divide by zero here — confirm expected inputs.
        img /= np.max(img)
        # HWC -> (C, 1, H, W): image channels are run through the net as a
        # batch of single-channel inputs. NOTE(review): hard-codes CUDA.
        img = torch.tensor(img).unsqueeze(0).permute(3, 0, 1, 2).cuda()
        img = self(img)
        # Reduce the channel batch to one map in [-1, 1], shape (1, H, W) -> (1, 1, H, W).
        img = postprocess(img, thresh=0.1).unsqueeze(1)
        return img
|
refnet/__init__.py
ADDED
|
File without changes
|
refnet/ldm/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .ddpm import LatentDiffusion
|
refnet/ldm/ddpm.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
wild mixture of
|
| 3 |
+
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
|
| 4 |
+
https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
|
| 5 |
+
https://github.com/CompVis/taming-transformers
|
| 6 |
+
-- merci
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import numpy as np
|
| 12 |
+
from contextlib import contextmanager
|
| 13 |
+
from functools import partial
|
| 14 |
+
|
| 15 |
+
from refnet.util import default, count_params, instantiate_from_config, exists
|
| 16 |
+
from refnet.ldm.util import make_beta_schedule, extract_into_tensor
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def disabled_train(self, mode=True):
    """No-op replacement for ``nn.Module.train`` so that train/eval mode can
    never be toggled on a frozen module. Returns the module unchanged."""
    return self
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def uniform_on_device(r1, r2, shape, device):
    """Sample a tensor of ``shape`` on ``device``, uniform between r2 and r1."""
    u = torch.rand(*shape, device=device)
    return r2 + (r1 - r2) * u
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def rescale_zero_terminal_snr(betas):
    """
    Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)


    Args:
        betas (`torch.FloatTensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.FloatTensor`: rescaled betas with zero terminal SNR
    """
    # Work in sqrt(alpha_bar) space.
    alphas_bar_sqrt = torch.cumprod(1.0 - betas, dim=0).sqrt()

    # Remember the endpoints before shifting.
    first = alphas_bar_sqrt[0].clone()
    last = alphas_bar_sqrt[-1].clone()

    # Shift so sqrt(alpha_bar_T) == 0, then rescale so sqrt(alpha_bar_0)
    # keeps its original value.
    alphas_bar_sqrt = (alphas_bar_sqrt - last) * (first / (first - last))

    # Convert back: square, undo the cumulative product, and recover betas.
    alphas_bar = alphas_bar_sqrt ** 2
    alphas = torch.cat([alphas_bar[0:1], alphas_bar[1:] / alphas_bar[:-1]])
    return 1 - alphas
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class DDPM(nn.Module):
    # classic DDPM with Gaussian diffusion, in image space
    def __init__(
        self,
        unet_config,
        timesteps = 1000,
        beta_schedule = "scaled_linear",
        image_size = 256,
        channels = 3,
        linear_start = 1e-4,
        linear_end = 2e-2,
        cosine_s = 8e-3,
        v_posterior = 0.,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
        parameterization = "eps",  # all assuming fixed variance schedules
        zero_snr = False,
        half_precision_dtype = "float16",
        version = "sdv1",
        *args,
        **kwargs
    ):
        """Gaussian diffusion base class: builds the UNet wrapper and
        registers the full noise-schedule buffers.

        Args:
            unet_config: config dict instantiated into the denoising UNet.
            timesteps: length of the diffusion chain.
            beta_schedule: schedule name passed to make_beta_schedule.
            v_posterior: interpolation weight between beta_tilde and beta for
                the posterior variance.
            parameterization: "eps" (noise prediction) or "v" (v-prediction).
            zero_snr: rescale to zero terminal SNR; requires v-prediction.
            half_precision_dtype: "float16" or "bfloat16" for mixed precision.
            version: model family tag; "sdxl" flips ``is_sdxl``.
        """
        super().__init__()
        assert parameterization in ["eps", "v"], "currently only supporting 'eps' and 'v'"
        assert half_precision_dtype in ["float16", "bfloat16"], "K-diffusion samplers do not support bfloat16, use float16 by default"
        if zero_snr:
            assert parameterization == "v", 'Zero SNR is only available for "v-prediction" model.'

        self.is_sdxl = (version == "sdxl")
        self.parameterization = parameterization
        print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
        # Filled in by subclasses (e.g. LatentDiffusion) / external setup.
        self.cond_stage_model = None
        self.img_embedder = None
        self.image_size = image_size  # try conv?
        self.channels = channels
        self.model = DiffusionWrapper(unet_config)
        count_params(self.model, verbose=True)
        self.v_posterior = v_posterior
        self.half_precision_dtype = torch.bfloat16 if half_precision_dtype == "bfloat16" else torch.float16
        self.register_schedule(beta_schedule=beta_schedule, timesteps=timesteps,
                               linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s, zero_snr=zero_snr)


    def register_schedule(self, beta_schedule="scaled_linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, zero_snr=False):
        """Compute the beta schedule and register all derived quantities as
        (non-persistent-by-default) float32 buffers, so they follow the module
        across devices."""
        betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
                                   cosine_s=cosine_s, zero_snr=zero_snr)

        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        # Re-derive the actual length from the schedule (shadows the argument).
        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
                1. - alphas_cumprod) + self.v_posterior * betas
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer('posterior_variance', to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
        self.register_buffer('posterior_mean_coef1', to_torch(
            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
        self.register_buffer('posterior_mean_coef2', to_torch(
            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))


    @contextmanager
    def ema_scope(self, context=None):
        # Temporarily swap in EMA weights, restoring the live weights on exit.
        # NOTE(review): relies on self.use_ema / self.model_ema, which are not
        # set anywhere in this class — presumably attached by a subclass or
        # trainer; confirm before calling.
        if self.use_ema:
            self.model_ema.store(self.model.parameters())
            self.model_ema.copy_to(self.model)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.model.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")


    def predict_start_from_z_and_v(self, x_t, t, v):
        # x_0 = sqrt(alpha_bar_t) * x_t - sqrt(1 - alpha_bar_t) * v
        # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
            extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
        )

    def add_noise(self, x_start, t, noise=None):
        # Forward process q(x_t | x_0); samples fresh Gaussian noise if none
        # is supplied, and casts back to x_start's dtype.
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise).to(x_start.dtype)

    def get_v(self, x, noise, t):
        # v-parameterization target: v = sqrt(alpha_bar_t) * eps - sqrt(1 - alpha_bar_t) * x_0
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
            extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
        )

    def normalize_timesteps(self, timesteps):
        # Identity here; subclasses may remap timestep indices.
        return timesteps
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class LatentDiffusion(DDPM):
    """DDPM operating in the latent space of a frozen first-stage autoencoder,
    with a frozen conditioning encoder."""

    def __init__(
        self,
        first_stage_config,
        cond_stage_config,
        scale_factor = 1.0,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.scale_factor = scale_factor
        # Both stages are frozen feature extractors: eval mode, no gradients.
        self.first_stage_model = instantiate_from_config(first_stage_config).eval().requires_grad_(False)
        self.cond_stage_model = instantiate_from_config(cond_stage_config).eval().requires_grad_(False)

    @torch.no_grad()
    def get_first_stage_encoding(self, x):
        """Encode an image into a scaled latent sample."""
        posterior = self.first_stage_model.encode(x)
        latent = posterior.sample() * self.scale_factor
        return latent.to(self.dtype).detach()

    @torch.no_grad()
    def decode_first_stage(self, z):
        """Undo the latent scaling and decode back to image space."""
        z = 1. / self.scale_factor * z
        return self.first_stage_model.decode(z.to(self.first_stage_model.dtype)).detach()

    def apply_model(self, x_noisy, t, cond):
        """Run the denoiser on a noisy latent with the given conditioning dict."""
        return self.model(x_noisy, t, **cond)

    def get_learned_embedding(self, c, *args, **kwargs):
        """Return detached (wd_emb, wd_logits, clip_emb) embeddings for ``c``."""
        wd_emb, wd_logits = map(
            lambda t: t.detach() if exists(t) else None,
            self.img_embedder.encode(c, **kwargs),
        )
        clip_emb = self.cond_stage_model.encode(c, **kwargs).detach()
        return wd_emb, wd_logits, clip_emb
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
class DiffusionWrapper(nn.Module):
    """Thin wrapper around the diffusion UNet that concatenates list-valued
    conditioning entries along the channel dim before forwarding."""

    def __init__(self, diff_model_config):
        super().__init__()
        self.diffusion_model = instantiate_from_config(diff_model_config)

    def forward(self, x, t, **cond):
        # These conditioning keys arrive as lists of tensors; fuse each into
        # a single tensor along dim 1 before calling the UNet.
        for key in ("context", "y", "concat"):
            if key in cond:
                cond[key] = torch.cat(cond[key], 1)

        return self.diffusion_model(x, t, **cond)
|
refnet/ldm/openaimodel.py
ADDED
|
@@ -0,0 +1,386 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import abstractmethod
|
| 2 |
+
import math
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch as th
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
from refnet.ldm.util import (
|
| 10 |
+
conv_nd,
|
| 11 |
+
linear,
|
| 12 |
+
avg_pool_nd,
|
| 13 |
+
zero_module,
|
| 14 |
+
normalization,
|
| 15 |
+
timestep_embedding,
|
| 16 |
+
)
|
| 17 |
+
from refnet.util import checkpoint_wrapper
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# dummy replace
|
| 22 |
+
def convert_module_to_f16(x):
    """No-op placeholder ("dummy replace") kept for interface compatibility."""
    return None
| 24 |
+
|
| 25 |
+
def convert_module_to_f32(x):
    """No-op placeholder ("dummy replace") kept for interface compatibility."""
    return None
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
## go
|
| 30 |
+
class AttentionPool2d(nn.Module):
    """
    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py

    Pools an NCHW feature map down to a single vector per sample: the spatial
    mean is prepended as a summary token, multi-head attention is applied, and
    the summary token's output is returned.
    """

    def __init__(
        self,
        spacial_dim: int,
        embed_dim: int,
        num_heads_channels: int,
        output_dim: int = None,
    ):
        super().__init__()
        # One positional embedding per spatial location plus the mean token.
        self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        b, c, *_spatial = x.shape
        x = x.reshape(b, c, -1)  # NC(HW)
        # Prepend the spatial mean as an extra summary token.
        x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)
        x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)
        x = self.qkv_proj(x)
        x = self.attention(x)
        x = self.c_proj(x)
        # Return only the summary (mean) token's attention output.
        return x[:, :, 0]
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        upsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            # 3D: keep the depth dimension, double only the two spatial dims.
            _, _, depth, height, width = x.shape
            x = F.interpolate(x, (depth, height * 2, width * 2), mode="nearest")
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(x) if self.use_conv else x
|
| 101 |
+
|
| 102 |
+
class TransposedUpsample(nn.Module):
    """Learned 2x upsampling without padding (strided transposed convolution)."""

    def __init__(self, channels, out_channels=None, ks=5):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        # Attribute name `up` is kept so existing checkpoints still load.
        self.up = nn.ConvTranspose2d(
            self.channels, self.out_channels, kernel_size=ks, stride=2
        )

    def forward(self, x):
        return self.up(x)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        downsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        # 3D inputs keep their depth dimension; only the inner two are halved.
        stride = (1, 2, 2) if dims == 3 else 2
        if use_conv:
            self.op = conv_nd(
                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
            )
        else:
            # Average pooling cannot change the channel count.
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        # read by the checkpoint_wrapper decorator on forward()
        self.checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        # When resampling, both the hidden path and the skip path are resampled
        # with parameter-free (use_conv=False) layers.
        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(
                emb_channels,
                # scale-shift norm needs two vectors: a scale and a shift
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            # zero-init so the block starts out as an identity residual
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    @checkpoint_wrapper
    def forward(self, x, emb):
        if self.updown:
            # Run norm+SiLU first, resample both the hidden state and the
            # residual input, then apply the conv on the resampled hidden state.
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        # Broadcast the embedding over the spatial dimensions.
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            # FiLM-style conditioning: modulate the normalized activations.
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.
    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        # num_head_channels == -1 means "use num_heads directly"; otherwise the
        # head count is derived from the per-head channel width.
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        if use_new_attention_order:
            # split qkv before split heads
            self.attention = QKVAttention(self.num_heads)
        else:
            # split heads before split qkv
            self.attention = QKVAttentionLegacy(self.num_heads)

        # zero-init so the block starts out as an identity residual
        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    @checkpoint_wrapper
    def forward(self, x):
        b, c, *spatial = x.shape
        # Flatten all spatial dims into a single sequence axis for attention.
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        # Residual connection, then restore the original spatial shape.
        return (x + h).reshape(b, c, *spatial)
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def count_flops_attn(model, _x, y):
    """
    A counter for the `thop` package to count the operations in an
    attention operation.
    Meant to be used like:
        macs, params = thop.profile(
            model,
            inputs=(inputs, timestamps),
            custom_ops={QKVAttention: QKVAttention.count_flops},
        )
    """
    b, c, *spatial = y[0].shape
    num_spatial = int(np.prod(spatial))
    # Two matmuls of identical cost: QK^T forms the weight matrix, then the
    # weights are applied to the value vectors.
    model.total_ops += th.DoubleTensor([2 * b * num_spatial ** 2 * c])
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        # Legacy layout: fold the heads into the batch first, then split the
        # per-head channels into q, k, v.
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        # Scale q and k by ch**-0.25 each (ch**-0.5 overall).
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        # Softmax in fp32 for numerical stability, then cast back.
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v)
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        # New layout: split q/k/v first, then fold the heads into the batch.
        q, k, v = qkv.chunk(3, dim=1)
        # Scale q and k by ch**-0.25 each (ch**-0.5 overall); applying the
        # factor before the matmul is more stable in f16 than dividing after.
        scale = 1 / math.sqrt(math.sqrt(ch))
        q = (q * scale).view(bs * self.n_heads, ch, length)
        k = (k * scale).view(bs * self.n_heads, ch, length)
        weight = th.einsum("bct,bcs->bts", q, k)
        # Softmax in fp32 for numerical stability, then cast back.
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
        return out.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
class Timestep(nn.Module):
    """Module wrapper around `timestep_embedding` producing sinusoidal embeddings."""

    def __init__(self, dim):
        super().__init__()
        # output embedding dimensionality
        self.dim = dim

    def forward(self, t):
        # t: 1-D tensor of (possibly fractional) timestep indices
        return timestep_embedding(t, self.dim)
|
refnet/ldm/util.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# adopted from
|
| 2 |
+
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
| 3 |
+
# and
|
| 4 |
+
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
|
| 5 |
+
# and
|
| 6 |
+
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
|
| 7 |
+
#
|
| 8 |
+
# thanks!
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
import math
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import numpy as np
|
| 15 |
+
from einops import repeat
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def rescale_zero_terminal_snr(betas):
    """
    Rescales betas to have zero terminal SNR, following
    https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).

    Args:
        betas (`torch.FloatTensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.FloatTensor`: rescaled betas with zero terminal SNR
    """
    # Work in sqrt(alpha_bar) space.
    sqrt_alpha_bar = torch.cumprod(1.0 - betas, dim=0).sqrt()

    first = sqrt_alpha_bar[0].clone()
    last = sqrt_alpha_bar[-1].clone()

    # Shift so the final step has exactly zero SNR, then rescale so the first
    # step keeps its original value.
    sqrt_alpha_bar = (sqrt_alpha_bar - last) * (first / (first - last))

    # Convert back from sqrt(alpha_bar) to per-step betas: undo the sqrt,
    # then undo the cumulative product.
    alpha_bar = sqrt_alpha_bar ** 2
    alphas = torch.cat([alpha_bar[0:1], alpha_bar[1:] / alpha_bar[:-1]])
    return 1 - alphas
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, zero_snr=False):
    """
    Build a diffusion beta schedule as a float64 numpy array of length n_timestep.

    :param schedule: one of "linear", "scaled_linear", "cosine",
        "squaredcos_cap_v2", "sqrt_linear", "sqrt".
    :param n_timestep: number of diffusion steps.
    :param linear_start: first beta for the linear-family schedules.
    :param linear_end: last beta for the linear-family schedules.
    :param cosine_s: small offset used by the cosine schedule.
    :param zero_snr: if True, rescale betas for zero terminal SNR
        (not applied to "squaredcos_cap_v2", which returns early).
    :raises ValueError: for an unknown schedule name.
    """
    if schedule == "linear":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
    elif schedule == "scaled_linear":
        betas = (
            torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
        )
    elif schedule == "cosine":
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        # BUGFIX: the original used np.clip on a torch tensor, which relies on
        # fragile numpy/torch interop and can return an ndarray, breaking the
        # trailing `betas.numpy()` call. Clamp in torch instead.
        betas = torch.clamp(betas, min=0, max=0.999)
    elif schedule == "squaredcos_cap_v2":  # used for karlo prior
        # return early (betas_for_alpha_bar already returns a numpy array)
        return betas_for_alpha_bar(
            n_timestep,
            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
        )
    elif schedule == "sqrt_linear":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
    elif schedule == "sqrt":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")

    if zero_snr:
        betas = rescale_zero_terminal_snr(betas)
    return betas.numpy()
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
    """Select the subset of DDPM timesteps the DDIM sampler will visit.

    :param ddim_discr_method: 'uniform' (even stride) or 'quad' (quadratic spacing).
    :raises NotImplementedError: for an unknown discretization method.
    """
    if ddim_discr_method == 'uniform':
        stride = num_ddpm_timesteps // num_ddim_timesteps
        ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, stride)))
    elif ddim_discr_method == 'quad':
        ddim_timesteps = (np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps) ** 2).astype(int)
    else:
        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')

    # add one to get the final alpha values right (the ones from first scale to data during sampling)
    steps_out = ddim_timesteps + 1
    if verbose:
        print(f'Selected timesteps for ddim sampler: {steps_out}')
    return steps_out
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
    """Compute (sigmas, alphas, alphas_prev) for the DDIM sampler.

    Sigmas follow the formula in https://arxiv.org/abs/2010.02502; eta=0
    gives deterministic DDIM sampling.
    """
    # select alphas for computing the variance schedule
    alphas = alphacums[ddim_timesteps]
    # previous-step alphas: shift by one, seeding with the very first alpha
    alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())

    sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
    if verbose:
        print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
        print(f'For the chosen value of eta, which is {eta}, '
              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
    return sigmas, alphas, alphas_prev
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].
    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    steps = num_diffusion_timesteps
    return np.array([
        min(1 - alpha_bar((i + 1) / steps) / alpha_bar(i / steps), max_beta)
        for i in range(steps)
    ])
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def extract_into_tensor(a, t, x_shape):
    """Gather per-timestep values from `a` at indices `t` and reshape them so
    they broadcast against a tensor of shape `x_shape` (batch-first)."""
    batch = t.shape[0]
    gathered = a.gather(-1, t)
    return gathered.reshape(batch, *((1,) * (len(x_shape) - 1)))
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class CheckpointFunction(torch.autograd.Function):
    """
    Gradient checkpointing: forward runs under no_grad (so intermediate
    activations are not stored); backward re-runs `run_function` with grad
    enabled to recompute them before differentiating.
    """

    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        # The first `length` args are the tensors fed to run_function; the
        # remainder are parameters that require gradients but are not inputs.
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        # Capture the autocast state so the recomputation in backward runs
        # under the same mixed-precision settings.
        ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
                                   "dtype": torch.get_autocast_gpu_dtype(),
                                   "cache_enabled": torch.is_autocast_cache_enabled()}
        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad(), \
                torch.amp.autocast("cuda", **ctx.gpu_autocast_kwargs):
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        # The two Nones correspond to run_function and length, which do not
        # receive gradients.
        return (None, None) + input_grads
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if repeat_only:
        # Broadcast the raw timestep value across the embedding dimension.
        return repeat(timesteps, 'b -> b d', d=dim)

    half = dim // 2
    exponents = torch.arange(start=0, end=half, dtype=torch.float32) / half
    freqs = torch.exp(-math.log(max_period) * exponents).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Odd dim: pad with a zero column so the output width is exactly `dim`.
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    # detach() so the in-place zeroing is not recorded by autograd
    for param in module.parameters():
        param.detach().zero_()
    return module
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def scale_module(module, scale):
    """
    Scale the parameters of a module and return it.
    """
    # detach() so the in-place scaling is not recorded by autograd
    for param in module.parameters():
        param.detach().mul_(scale)
    return module
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    non_batch_dims = list(range(1, len(tensor.shape)))
    return tensor.mean(dim=non_batch_dims)
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def normalization(channels):
    """
    Make a standard normalization layer.
    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    # 32 groups; GroupNorm32 computes in the weight dtype and casts back.
    return GroupNorm32(32, channels)
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
|
| 243 |
+
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
    """Sigmoid-weighted linear unit: x * sigmoid(x)."""

    def forward(self, x):
        return torch.sigmoid(x) * x
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
class GroupNorm32(nn.GroupNorm):
    """GroupNorm that computes in the affine weight's dtype (typically fp32)
    and casts the result back to the input dtype, for fp16 stability."""

    def forward(self, x):
        normed = super().forward(x.to(self.weight.dtype))
        return normed.type(x.dtype)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    :raises ValueError: if dims is not 1, 2, or 3.
    """
    conv_cls = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}.get(dims)
    if conv_cls is None:
        raise ValueError(f"unsupported dimensions: {dims}")
    return conv_cls(*args, **kwargs)
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def linear(*args, **kwargs):
    """
    Create a linear module.
    """
    return nn.Linear(*args, **kwargs)
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    :raises ValueError: if dims is not 1, 2, or 3.
    """
    pool_cls = {1: nn.AvgPool1d, 2: nn.AvgPool2d, 3: nn.AvgPool3d}.get(dims)
    if pool_cls is None:
        raise ValueError(f"unsupported dimensions: {dims}")
    return pool_cls(*args, **kwargs)
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def noise_like(shape, device, repeat=False):
    """Sample standard-normal noise of `shape` on `device`.

    If `repeat` is True, a single sample of shape (1, *shape[1:]) is drawn and
    tiled along the batch axis, so every batch element shares the same noise.
    """
    if repeat:
        single = torch.randn((1, *shape[1:]), device=device)
        return single.repeat(shape[0], *((1,) * (len(shape) - 1)))
    return torch.randn(shape, device=device)
|
refnet/modules/__init__.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import namedtuple
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def wd_v14_swin2_tagger_config():
    """Build the timm-style config record for the WD v1.4 SwinV2 tagger model."""
    CustomConfig = namedtuple('CustomConfig', [
        'architecture', 'num_classes', 'num_features', 'global_pool', 'model_args', 'pretrained_cfg'
    ])

    model_args = {
        "act_layer": "gelu",
        "img_size": 448,
        "window_size": 14,
    }
    pretrained_cfg = {
        "custom_load": False,
        "input_size": [3, 448, 448],
        "fixed_input_size": False,
        "interpolation": "bicubic",
        "crop_pct": 1.0,
        "crop_mode": "center",
        "mean": [0.5, 0.5, 0.5],
        "std": [0.5, 0.5, 0.5],
        "num_classes": 9083,
        "pool_size": None,
        "first_conv": None,
        "classifier": None,
    }
    return CustomConfig(
        architecture="swinv2_base_window8_256",
        num_classes=9083,
        num_features=1024,
        global_pool="avg",
        model_args=model_args,
        pretrained_cfg=pretrained_cfg,
    )
|
refnet/modules/attention.py
ADDED
|
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from calendar import c
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from einops import rearrange
|
| 5 |
+
from refnet.util import exists, default, checkpoint_wrapper
|
| 6 |
+
from .layers import RMSNorm
|
| 7 |
+
from .attn_utils import *
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def create_masked_attention_bias(
    mask: torch.Tensor,
    threshold: float,
    num_heads: int,
    context_len: int
):
    """Build an additive attention bias that routes queries by a spatial mask.

    Foreground queries (``mask > threshold``) may only attend to the first
    half of the context; background queries only to the second half. Blocked
    key positions receive ``-inf`` bias.

    Args:
        mask: Per-query mask of shape (b, seq_len, 1); broadcast over keys.
        threshold: Foreground/background cut-off applied to ``mask``.
        num_heads: Number of attention heads the bias is replicated over.
        context_len: Number of key/value tokens.

    Returns:
        Bias tensor of shape (b, num_heads, seq_len, context_len). The
        underlying storage is allocated with the key dimension rounded up to
        a multiple of 8 (xformers' memory_efficient_attention alignment
        requirement) and the returned view is sliced back to ``context_len``
        so it matches the keys.
    """
    b, seq_len, _ = mask.shape
    half_len = context_len // 2

    # Round the key dimension up to a multiple of 8 so the bias stride
    # satisfies the memory-efficient-attention alignment constraint.
    if context_len % 8 != 0:
        padded_context_len = ((context_len + 7) // 8) * 8
    else:
        padded_context_len = context_len

    fg_bias = torch.zeros(b, seq_len, padded_context_len, device=mask.device, dtype=mask.dtype)
    bg_bias = torch.zeros(b, seq_len, padded_context_len, device=mask.device, dtype=mask.dtype)

    fg_bias[:, :, half_len:] = -float('inf')
    bg_bias[:, :, :half_len] = -float('inf')
    attn_bias = torch.where(mask > threshold, fg_bias, bg_bias)
    attn_bias = attn_bias.unsqueeze(1).repeat_interleave(num_heads, dim=1)
    # BUGFIX: slice the padded allocation back to the true context length.
    # The previous code returned the padded width, so whenever context_len
    # was not a multiple of 8 the bias shape no longer matched the keys and
    # SDPA / xformers rejected it (attn_mask must broadcast to (..., L, S)).
    return attn_bias[:, :, :, :context_len]
|
| 31 |
+
|
| 32 |
+
class Identity(nn.Module):
    """No-op module: returns its first input unchanged.

    Extra positional and keyword arguments are accepted and ignored, so this
    can stand in for optional components (norms, RoPE) whose call sites pass
    additional arguments.
    """

    def forward(self, x, *args, **kwargs):
        return x
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# Rotary Positional Embeddings implementation
class RotaryPositionalEmbeddings(nn.Module):
    """2D (axial) rotary positional embeddings for grid-shaped token sequences.

    Half of each head's channels encode the row position, the other half the
    column position. Rotations are precomputed as unit-magnitude complex
    numbers (``torch.polar``) for up to ``max_seq_len`` positions per axis.

    NOTE(review): ``freq_h`` and ``freq_w`` are registered from the same
    tensor, so rows and columns share one frequency table — confirm this is
    intentional.
    """
    def __init__(self, dim, max_seq_len=1024, theta=10000.0):
        super().__init__()
        assert dim % 2 == 0, "Dimension must be divisible by 2"
        # Each spatial axis (h, w) gets half of the head dim; channel pairs
        # are later viewed as complex numbers, giving dim//4 frequencies per axis.
        dim = dim // 2
        self.max_seq_len = max_seq_len
        freqs = torch.outer(
            torch.arange(self.max_seq_len),
            1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim))
        )
        # Unit-magnitude complex rotations e^{i * pos * freq}.
        freqs = torch.polar(torch.ones_like(freqs), freqs)
        # Non-persistent: recomputed on construction, never saved in checkpoints.
        self.register_buffer("freq_h", freqs, persistent=False)
        self.register_buffer("freq_w", freqs, persistent=False)

    def forward(self, x, grid_size):
        """Rotate ``x`` by its 2D grid position.

        Args:
            x: Tensor of shape (b, n, heads, c); n must equal h * w.
            grid_size: (h, w) spatial extent of the token grid.

        Returns:
            Rotated tensor with the same shape and dtype as ``x`` (the
            rotation itself runs in float32 complex space).
        """
        bs, seq_len, heads = x.shape[:3]
        h, w = grid_size

        # View channel pairs as complex: (b, n, heads, c//2) complex values.
        x_complex = torch.view_as_complex(
            x.float().reshape(bs, seq_len, heads, -1, 2)
        )
        # Row rotations vary along h, column rotations along w; concatenating
        # yields one rotation per (row, col) cell, broadcast over heads.
        freqs = torch.cat([
            self.freq_h[:h].view(1, h, 1, -1).expand(bs, h, w, -1),
            self.freq_w[:w].view(1, 1, w, -1).expand(bs, h, w, -1)
        ], dim=-1).reshape(bs, seq_len, 1, -1)

        x_out = x_complex * freqs
        # Back to interleaved real pairs, flattened to the original channel dim.
        x_out = torch.view_as_real(x_out).flatten(3)

        return x_out.type_as(x)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class MemoryEfficientAttention(nn.Module):
    """Multi-head attention with pluggable memory-efficient backends.

    Q/K/V projections are bias-free linears; the actual kernel is selected
    inside ``attn_processor`` (flash-attn / xformers / PyTorch SDPA).
    Optional features: RMS-norm on Q/K (``qk_norm``), 2D rotary embeddings
    (``rope``), causal extra kwargs (``causal`` -> ``causal_ops``), and a
    "split cross-attention" mode driven by a spatial mask that routes
    foreground queries to the first half of the context and background
    queries to the second half.
    """
    # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
    def __init__(
        self,
        query_dim,
        context_dim = None,
        heads = None,
        dim_head = 64,
        dropout = 0.0,
        log = False,
        causal = False,
        rope = False,
        max_seq_len = 1024,
        qk_norm = False,
        **kwargs
    ):
        super().__init__()
        if log:
            print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
                  f"{heads} heads.")

        # heads defaults to query_dim / dim_head; context_dim to self-attention.
        heads = heads or query_dim // dim_head
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.heads = heads
        self.dim_head = dim_head

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

        # Identity() accepts and ignores extra args, so call sites need no branching.
        self.q_norm = RMSNorm(inner_dim) if qk_norm else Identity()
        self.k_norm = RMSNorm(inner_dim) if qk_norm else Identity()
        self.rope = RotaryPositionalEmbeddings(dim_head, max_seq_len=max_seq_len) if rope else Identity()
        self.attn_ops = causal_ops if causal else {}

        # default setting for split cross-attention
        self.bg_scale = 1.
        self.fg_scale = 1.
        self.merge_scale = 0.
        self.mask_threshold = 0.05

    @checkpoint_wrapper
    def forward(
        self,
        x,
        context=None,
        mask=None,
        scale=1.,
        scale_factor=None,
        grid_size=None,
        **kwargs,
    ):
        """Run attention; with a mask, dispatch to split cross-attention.

        Args:
            x: Query tokens (b, n, query_dim).
            context: Key/value tokens; defaults to ``x`` (self-attention).
            mask: Optional spatial mask (b, c, h, w) enabling masked_forward.
            scale: Multiplier applied to the attention output.
            scale_factor: Resize factor aligning ``mask`` to the token grid
                (required when ``mask`` is given).
            grid_size: (h, w) token grid, forwarded to RoPE.
        """
        context = default(context, x)

        if exists(mask):
            out = self.masked_forward(x, context, mask, scale, scale_factor)
        else:
            q = self.to_q(x)
            k = self.to_k(context)
            v = self.to_v(context)
            out = self.attn_forward(q, k, v, scale, grid_size)

        return self.to_out(out)

    def attn_forward(self, q, k, v, scale=1., grid_size=None, mask=None):
        """Core attention: norm + RoPE on q/k, then backend kernel.

        Tensors are reshaped to (b, n, heads, dim_head) — the layout the
        flash-attn / xformers kernels expect — and flattened back afterwards.
        """
        q, k = map(
            lambda t:
            self.rope(rearrange(t, "b n (h c) -> b n h c", h=self.heads), grid_size),
            (self.q_norm(q), self.k_norm(k))
        )
        v = rearrange(v, "b n (h c) -> b n h c", h=self.heads)
        out = attn_processor(q, k, v, attn_mask=mask, **self.attn_ops) * scale
        out = rearrange(out, "b n h c -> b n (h c)")
        return out

    def masked_forward(self, x, context, mask, scale=1., scale_factor=None):
        """Split cross-attention: fg/bg queries attend to different context halves.

        With ``merge_scale > 0`` two full attention passes are run (one per
        context half) and blended; otherwise a single pass uses an additive
        bias from ``create_masked_attention_bias`` to route each query.
        """
        # split cross-attention function
        def qkv_forward(x, context):
            q = self.to_q(x)
            k = self.to_k(context)
            v = self.to_v(context)
            return q, k, v

        assert exists(scale_factor), "Scale factor must be assigned before masked attention"
        # Downsample the spatial mask to the token resolution and flatten to
        # one value per query token.
        mask = rearrange(
            F.interpolate(mask, scale_factor=scale_factor, mode="bicubic"),
            "b c h w -> b (h w) c"
        ).contiguous()

        if self.merge_scale > 0:
            # split cross-attention with merging scale, need two times forward
            # Context is assumed to hold foreground tokens in the first half
            # and background tokens in the second half along dim 1.
            c1, c2 = context.chunk(2, dim=1)

            # Background region cross-attention
            q2, k2, v2 = qkv_forward(x, c2)
            bg_out = self.attn_forward(q2, k2, v2, scale) * self.bg_scale

            # Foreground region cross-attention
            q1, k1, v1 = qkv_forward(x, c1)
            fg_out = self.attn_forward(q1, k1, v1, scale) * self.fg_scale

            # Blend bg into fg, then select per-token by mask threshold.
            fg_out = fg_out * (1 - self.merge_scale) + bg_out * self.merge_scale
            return torch.where(mask < self.mask_threshold, bg_out, fg_out)

        else:
            # Single pass: -inf bias hides the opposite half per query.
            attn_mask = create_masked_attention_bias(
                mask, self.mask_threshold, self.heads, context.size(1)
            )
            q, k, v = qkv_forward(x, context)
            return self.attn_forward(q, k, v, mask=attn_mask) * scale
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class MultiModalAttention(MemoryEfficientAttention):
    """Cross-attention over two context modalities with separate K/V projections.

    The stacked context is split along dim 1 into two streams; the first
    reuses the inherited ``to_k``/``to_v``, the second gets its own
    ``to_k_2``/``to_v_2``, and the two attention outputs are summed with
    per-modality scales.
    """

    def __init__(self, query_dim, context_dim_2, heads=8, dim_head=64, qk_norm=False, *args, **kwargs):
        super().__init__(query_dim, heads=heads, dim_head=dim_head, qk_norm=qk_norm, *args, **kwargs)
        inner_dim = dim_head * heads
        # Projections for the second modality (first modality uses to_k/to_v).
        self.to_k_2 = nn.Linear(context_dim_2, inner_dim, bias=False)
        self.to_v_2 = nn.Linear(context_dim_2, inner_dim, bias=False)
        self.k2_norm = RMSNorm(inner_dim) if qk_norm else Identity()

    def forward(self, x, context=None, mask=None, scale=1., grid_size=None):
        """Attend ``x`` over both modality streams and sum the results.

        Args:
            x: Query tokens (b, n, query_dim).
            context: 4-D stacked contexts, split in two along dim 1.
            mask: Not supported here (raises NotImplementedError if given).
            scale: Scalar (broadcast to both modalities) or a pair
                (scale_modality_1, scale_modality_2).
            grid_size: (h, w) token grid forwarded to RoPE.
        """
        # Broadcast a scalar scale to one scale per modality.
        if not isinstance(scale, (list, tuple)):
            scale = (scale, scale)
        assert len(context.shape) == 4, "Multi-modal attention requires different context inputs to be (b, m, n c)"
        # NOTE(review): after chunk the streams remain 4-D (b, m/2, n, c),
        # while the rearranges below expect 3-D — confirm callers collapse
        # the modality dim (m == 2 with a squeeze) upstream.
        context, context2 = context.chunk(2, dim=1)

        q = self.to_q(x)
        k = self.to_k(context)
        v = self.to_v(context)
        k2 = self.to_k_2(context2)
        # BUGFIX: values for the second modality must come from to_v_2; the
        # original called to_k_2 twice, so to_v_2 was never used and the
        # second modality used its keys as values.
        v2 = self.to_v_2(context2)

        b, _, _ = q.shape
        q, k, k2 = map(
            lambda t: self.rope(rearrange(t, "b n (h c) -> b n h c", h=self.heads), grid_size),
            (self.q_norm(q), self.k_norm(k), self.k2_norm(k2))
        )
        v, v2 = map(lambda t: rearrange(t, "b n (h c) -> b n h c", h=self.heads), (v, v2))

        # Weighted sum of the per-modality attention outputs.
        out = (attn_processor(q, k, v, **self.attn_ops) * scale[0] +
               attn_processor(q, k2, v2, **self.attn_ops) * scale[1])

        if exists(mask):
            raise NotImplementedError
        out = rearrange(out, "b n h c -> b n (h c)")
        return self.to_out(out)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
class MultiScaleCausalAttention(MemoryEfficientAttention):
    """Attention over a sequence made of concatenated multi-scale segments.

    The token stream is a concatenation of per-scale segments (lengths in
    ``token_lens``, each with its own (h, w) grid in ``grid_size``).
    Causality is enforced at segment granularity: queries of segment i
    attend to all keys up to and including segment i, never to later ones.

    NOTE(review): this override does not carry the parent's
    ``@checkpoint_wrapper`` on forward — confirm that is intentional.
    """
    def forward(
        self,
        x,
        context=None,
        mask=None,
        scale=1.,
        scale_factor=None,
        grid_size=None,
        token_lens=None
    ):
        context = default(context, x)
        q = self.to_q(x)
        k = self.to_k(context)
        v = self.to_v(context)
        out = self.attn_forward(q, k, v, scale, grid_size=grid_size, token_lens=token_lens)
        return self.to_out(out)

    def attn_forward(self, q, k, v, scale = 1., grid_size = None, token_lens = None):
        """Segment-causal attention with per-segment 2D RoPE.

        Args:
            q, k, v: (b, n, heads*dim_head) with n == sum(token_lens)
                (plus one extra leading token — see below).
            grid_size: One (h, w) grid per segment.
            token_lens: Length of each segment's token run.
        """
        q, k, v = map(
            lambda t: rearrange(t, "b n (h c) -> b n h c", h=self.heads),
            (self.q_norm(q), self.k_norm(k), v)
        )

        attn_output = []
        prev_idx = 0
        for idx, (grid, length) in enumerate(zip(grid_size, token_lens)):
            # The first segment carries one extra leading token (idx == 0
            # adds 1); RoPE is skipped for that token since it has no grid
            # position. NOTE(review): presumably a global/conditioning token
            # — confirm against the caller.
            end_idx = prev_idx + length + (idx == 0)
            rope_prev_idx = prev_idx + (idx == 0)
            rope_slice = slice(rope_prev_idx, end_idx)

            # Rotate this segment's q/k in place with its own grid.
            q[:, rope_slice] = self.rope(q[:, rope_slice], grid)
            k[:, rope_slice] = self.rope(k[:, rope_slice], grid)
            qs = q[:, prev_idx: end_idx]
            # Segment-causal: keys/values are the full prefix up to end_idx.
            ks, vs = map(lambda t: t[:, :end_idx], (k, v))

            # clone() detaches the kernel inputs from the in-place-updated
            # base tensors (contiguity / autograd-versioning safety).
            attn_output.append(attn_processor(qs.clone(), ks.clone(), vs.clone()) * scale)
            prev_idx = end_idx
        attn_output = rearrange(torch.cat(attn_output, 1), "b n h c -> b n (h c)")
        return attn_output
|
| 264 |
+
|
| 265 |
+
# if FLASH_ATTN_3_AVAILABLE or FLASH_ATTN_AVAILABLE:
|
| 266 |
+
# k_chunks = []
|
| 267 |
+
# v_chunks = []
|
| 268 |
+
# kv_token_lens = []
|
| 269 |
+
# prev_idx = 0
|
| 270 |
+
# for idx, (grid, length) in enumerate(zip(grid_size, token_lens)):
|
| 271 |
+
# end_idx = prev_idx + length + (idx == 0)
|
| 272 |
+
# rope_prev_idx = prev_idx + (idx == 0)
|
| 273 |
+
|
| 274 |
+
# rope_slice = slice(rope_prev_idx, end_idx)
|
| 275 |
+
# q[:, rope_slice], k[:, rope_slice], v[:, rope_slice] = map(
|
| 276 |
+
# lambda t: self.rope(t[:, rope_slice], grid),
|
| 277 |
+
# (q, k, v)
|
| 278 |
+
# )
|
| 279 |
+
# kv_token_lens.append(end_idx+1)
|
| 280 |
+
# k_chunks.append(k[:, :end_idx])
|
| 281 |
+
# v_chunks.append(v[:, :end_idx])
|
| 282 |
+
# prev_idx = end_idx
|
| 283 |
+
# k = torch.cat(k_chunks, 1)
|
| 284 |
+
# v = torch.cat(v_chunks, 1)
|
| 285 |
+
# B, N, H, C = q.shape
|
| 286 |
+
# token_lens = torch.tensor(token_lens, device=q.device, dtype=torch.int32)
|
| 287 |
+
# kv_token_lens = torch.tensor(kv_token_lens, device=q.device, dtype=torch.int32)
|
| 288 |
+
# token_lens[0] = token_lens[0] + 1
|
| 289 |
+
#
|
| 290 |
+
# cu_seqlens_q, cu_seqlens_kv = map(lambda t:
|
| 291 |
+
# torch.cat([t.new_zeros([1]), t]).cumsum(0, dtype=torch.int32),
|
| 292 |
+
# (token_lens, kv_token_lens)
|
| 293 |
+
# )
|
| 294 |
+
# max_seqlen_q, max_seqlen_kv = map(lambda t: int(t.max()), (token_lens, kv_token_lens))
|
| 295 |
+
#
|
| 296 |
+
# q_flat = q.reshape(-1, H, C).contiguous()
|
| 297 |
+
# k_flat = k.reshape(-1, H, C).contiguous()
|
| 298 |
+
# v_flat = v.reshape(-1, H, C).contiguous()
|
| 299 |
+
# out_flat = flash_attn_varlen_func(
|
| 300 |
+
# q=q_flat, k=k_flat, v=v_flat,
|
| 301 |
+
# cu_seqlens_q=cu_seqlens_q,
|
| 302 |
+
# cu_seqlens_k=cu_seqlens_kv,
|
| 303 |
+
# max_seqlen_q=max_seqlen_q,
|
| 304 |
+
# max_seqlen_k=max_seqlen_kv,
|
| 305 |
+
# causal=True,
|
| 306 |
+
# )
|
| 307 |
+
#
|
| 308 |
+
# out = rearrange(out_flat, "(b n) h c -> b n (h c)", b=B, n=N)
|
| 309 |
+
# return out * scale
|
refnet/modules/attn_utils.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
|
| 4 |
+
# Precision inputs are cast to before calling half-precision flash-attn kernels.
ATTN_PRECISION = torch.float16

# Backend availability probing, in preference order: FlashAttention-3,
# then FlashAttention-2, then xformers. attn_processor falls back to
# PyTorch SDPA when none of them import.
try:
    import flash_attn_interface
    FLASH_ATTN_3_AVAILABLE = True
    FLASH_ATTN_AVAILABLE = False

except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False
    try:
        import flash_attn
        FLASH_ATTN_AVAILABLE = True
    except ModuleNotFoundError:
        FLASH_ATTN_AVAILABLE = False

try:
    import xformers.ops
    # (sic) misspelled name kept: other modules import this exact spelling.
    XFORMERS_IS_AVAILBLE = True
# BUGFIX: the bare `except:` also swallowed KeyboardInterrupt/SystemExit.
# `Exception` still tolerates broken installs (e.g. CUDA library load
# errors) while letting process-control exceptions propagate.
except Exception:
    XFORMERS_IS_AVAILBLE = False
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def half(x):
    """Cast ``x`` to the module's attention precision unless already half-width.

    float16 and bfloat16 tensors pass through unchanged; everything else is
    converted to ``ATTN_PRECISION``.
    """
    if x.dtype in (torch.float16, torch.bfloat16):
        return x
    return x.to(ATTN_PRECISION)
|
| 30 |
+
|
| 31 |
+
def attn_processor(q, k, v, attn_mask = None, *args, **kwargs):
    """Dispatch scaled-dot-product attention to the best available backend.

    Inputs are expected in (batch, seq, heads, head_dim) layout — the layout
    flash-attn and xformers use natively; for PyTorch SDPA they are
    transposed to (batch, heads, seq, head_dim) and back.

    With ``attn_mask`` set, only xformers (as ``attn_bias``) or SDPA are
    used, since the flash-attn dense kernels take no additive mask here.
    Without a mask, preference order is FA3 > FA2 > xformers > SDPA; for the
    flash-attn paths inputs are cast to half precision and the output cast
    back to v's original dtype.
    """
    if attn_mask is not None:
        if XFORMERS_IS_AVAILBLE:
            out = xformers.ops.memory_efficient_attention(
                q, k, v, attn_bias=attn_mask, *args, **kwargs
            )
        else:
            # SDPA wants (b, h, n, c); transpose in and out.
            q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))
            out = F.scaled_dot_product_attention(
                q, k, v, attn_mask=attn_mask, *args, **kwargs
            ).transpose(1, 2)
    else:
        if FLASH_ATTN_3_AVAILABLE:
            dtype = v.dtype
            q, k, v = map(lambda t: half(t), (q, k, v))
            # FA3 returns a tuple; element 0 is the attention output.
            out = flash_attn_interface.flash_attn_func(q, k, v, *args, **kwargs)[0].to(dtype)
        elif FLASH_ATTN_AVAILABLE:
            dtype = v.dtype
            q, k, v = map(lambda t: half(t), (q, k, v))
            out = flash_attn.flash_attn_func(q, k, v, *args, **kwargs).to(dtype)
        elif XFORMERS_IS_AVAILBLE:
            out = xformers.ops.memory_efficient_attention(q, k, v, *args, **kwargs)
        else:
            q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))
            out = F.scaled_dot_product_attention(q, k, v, *args, **kwargs).transpose(1, 2)
    return out
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def flash_attn_varlen_func(q, k, v, **kwargs):
    """Variable-length flash attention, preferring FA3 over FA2.

    FlashAttention-3's varlen kernel returns a tuple whose first element is
    the attention output, so it is unwrapped here to match FA2's interface.

    NOTE(review): assumes one of the flash-attn packages imported
    successfully — there is no SDPA fallback; confirm callers gate on
    FLASH_ATTN_3_AVAILABLE / FLASH_ATTN_AVAILABLE.
    """
    if FLASH_ATTN_3_AVAILABLE:
        return flash_attn_interface.flash_attn_varlen_func(q, k, v, **kwargs)[0]
    else:
        return flash_attn.flash_attn_varlen_func(q, k, v, **kwargs)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def split_tensor_by_mask(tensor: torch.Tensor, mask: torch.Tensor, threshold: float = 0.5):
    """Pack a sequence tensor into a foreground-then-background layout.

    Foreground positions (``mask > threshold``) are gathered to the front of
    the output, background positions after them. Each of the two segments is
    zero-padded to the longest such segment in the batch, so the output
    width is ``max_fg_len + max_bg_len``.

    Args:
        tensor: Input of shape (batch, seq_len, *dims).
        mask: Mask of shape (batch, seq_len) or (batch, seq_len, 1).
        threshold: Cut-off used to binarize the mask.

    Returns:
        Tuple ``(packed, fg_indices, bg_indices, original_shape)`` — the
        last three are exactly what ``restore_tensor_from_split`` needs to
        invert the packing.
    """
    bsz, _, *feat_dims = tensor.shape

    # Normalize the mask to a (batch, seq_len) boolean map.
    if mask.dim() == 2:
        mask = mask.unsqueeze(-1)
    keep = (mask > threshold).squeeze(-1)

    # Per-batch index lists, kept for the inverse operation.
    fg_indices = [keep[i].nonzero(as_tuple=True)[0] for i in range(bsz)]
    bg_indices = [(~keep[i]).nonzero(as_tuple=True)[0] for i in range(bsz)]

    # Segment widths: longest foreground / background run in the batch.
    fg_width = keep.sum(dim=1).max().item()
    bg_width = (~keep).sum(dim=1).max().item()

    # Degenerate case: nothing to pack on either side.
    if fg_width == 0 and bg_width == 0:
        empty = tensor.new_zeros(bsz, 0, *feat_dims)
        return empty, fg_indices, bg_indices, tensor.shape

    packed = tensor.new_zeros(bsz, fg_width + bg_width, *feat_dims)

    # Scatter each batch element: foreground first, background after the
    # foreground segment boundary; trailing slots stay zero (padding).
    for i in range(bsz):
        fg, bg = fg_indices[i], bg_indices[i]
        if fg.numel() > 0:
            packed[i, :fg.numel()] = tensor[i, fg]
        if bg.numel() > 0:
            packed[i, fg_width:fg_width + bg.numel()] = tensor[i, bg]

    return packed, fg_indices, bg_indices, tensor.shape
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def restore_tensor_from_split(split_tensor: torch.Tensor, fg_indices: list, bg_indices: list,
                              original_shape: torch.Size):
    """Invert ``split_tensor_by_mask``: scatter packed segments back in place.

    Args:
        split_tensor: Packed tensor produced by ``split_tensor_by_mask``.
        fg_indices: Per-batch foreground index tensors.
        bg_indices: Per-batch background index tensors.
        original_shape: Shape of the tensor before packing.

    Returns:
        Tensor of ``original_shape`` with every element back at its original
        sequence position (padding slots are discarded).
    """
    bsz, seq_len = original_shape[:2]
    feat_dims = original_shape[2:]

    # The foreground segment occupies the first max-foreground-length slots.
    fg_width = max((len(fg) for fg in fg_indices), default=0)

    out = torch.zeros(bsz, seq_len, *feat_dims,
                      device=split_tensor.device, dtype=split_tensor.dtype)

    # Nothing was packed at all — return the zero tensor as-is.
    if split_tensor.shape[1] == 0:
        return out

    fg_block = split_tensor[:, :fg_width] if fg_width > 0 else None
    bg_block = split_tensor[:, fg_width:] if split_tensor.shape[1] > fg_width else None

    # Scatter each batch element's segments back to their stored indices;
    # only the first len(indices) rows of each segment are real data.
    for i in range(bsz):
        fg, bg = fg_indices[i], bg_indices[i]
        if fg_block is not None and len(fg) > 0:
            out[i, fg] = fg_block[i, :len(fg)]
        if bg_block is not None and len(bg) > 0:
            out[i, bg] = bg_block[i, :len(bg)]

    return out
|
refnet/modules/embedder.py
ADDED
|
@@ -0,0 +1,489 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import math
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
from einops import rearrange
|
| 7 |
+
from refnet.util import exists, append_dims
|
| 8 |
+
from refnet.sampling import tps_warp
|
| 9 |
+
from refnet.ldm.openaimodel import Timestep, zero_module
|
| 10 |
+
|
| 11 |
+
import timm
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import torchvision.transforms
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
|
| 17 |
+
from huggingface_hub import hf_hub_download
|
| 18 |
+
from torch.utils.checkpoint import checkpoint
|
| 19 |
+
from safetensors.torch import load_file
|
| 20 |
+
from transformers import (
|
| 21 |
+
T5EncoderModel,
|
| 22 |
+
T5Tokenizer,
|
| 23 |
+
CLIPVisionModelWithProjection,
|
| 24 |
+
CLIPTextModel,
|
| 25 |
+
CLIPTokenizer,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
# open_clip architecture name -> pretrained weights tag
# (None: the hf-hub:* identifier already pins the weights).
versions = {
    "ViT-bigG-14": "laion2b_s39b_b160k",
    "ViT-H-14": "laion2b_s32b_b79k",  # resblocks layers: 32
    "ViT-L-14": "laion2b_s32b_b82k",
    "hf-hub:apple/DFN5B-CLIP-ViT-H-14-384": None,  # arch name [DFN-ViT-H]
}
# Same architectures mapped to their Hugging Face Hub repo ids
# (for the transformers-based loaders).
hf_versions = {
    "ViT-bigG-14": "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
    "ViT-H-14": "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
    "ViT-L-14": "openai/clip-vit-large-patch14",
}
# Download/cache root for pretrained weights; honors HF_HOME when set.
cache_dir = os.environ.get("HF_HOME", "./pretrained_models")
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class WDv14SwinTransformerV2(nn.Module):
|
| 43 |
+
"""
|
| 44 |
+
WD-v14-tagger
|
| 45 |
+
Author: Smiling Wolf
|
| 46 |
+
Link: https://huggingface.co/SmilingWolf/wd-v1-4-swinv2-tagger-v2
|
| 47 |
+
"""
|
| 48 |
+
negative_logit = -22
|
| 49 |
+
|
| 50 |
+
def __init__(
|
| 51 |
+
self,
|
| 52 |
+
input_size = 448,
|
| 53 |
+
antialias = True,
|
| 54 |
+
layer_idx = 0.,
|
| 55 |
+
load_tag = False,
|
| 56 |
+
logit_threshold = None,
|
| 57 |
+
direct_forward = False,
|
| 58 |
+
):
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
input_size: Input image size
|
| 63 |
+
antialias: Antialias during rescaling
|
| 64 |
+
layer_idx: Extracted feature layer
|
| 65 |
+
load_tag: Set it to true if use the embedder for image classification
|
| 66 |
+
logit_threshold: Filtering specific channels in logits output
|
| 67 |
+
"""
|
| 68 |
+
from refnet.modules import wd_v14_swin2_tagger_config
|
| 69 |
+
super().__init__()
|
| 70 |
+
custom_config = wd_v14_swin2_tagger_config()
|
| 71 |
+
self.model: nn.Module = timm.create_model(
|
| 72 |
+
custom_config.architecture,
|
| 73 |
+
pretrained = False,
|
| 74 |
+
num_classes = custom_config.num_classes,
|
| 75 |
+
global_pool = custom_config.global_pool,
|
| 76 |
+
**custom_config.model_args
|
| 77 |
+
)
|
| 78 |
+
self.image_size = input_size
|
| 79 |
+
self.antialias = antialias
|
| 80 |
+
self.layer_idx = layer_idx
|
| 81 |
+
self.load_tag = load_tag
|
| 82 |
+
self.logit_threshold = logit_threshold
|
| 83 |
+
self.direct_forward = direct_forward
|
| 84 |
+
|
| 85 |
+
self.load_from_pretrained_url(load_tag)
|
| 86 |
+
self.get_transformer_length()
|
| 87 |
+
self.model.eval()
|
| 88 |
+
self.model.requires_grad_(False)
|
| 89 |
+
|
| 90 |
+
if self.direct_forward:
|
| 91 |
+
self.model.forward = self.model.forward_features.__get__(self.model, self.model.__class__)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def load_from_pretrained_url(self, load_tag=False):
|
| 95 |
+
import pandas as pd
|
| 96 |
+
from torch.hub import download_url_to_file
|
| 97 |
+
from data.tag_utils import load_labels, color_tag_index, geometry_tag_index
|
| 98 |
+
|
| 99 |
+
ckpt_path = os.path.join(cache_dir, "wd-v14-swin2-tagger.safetensors")
|
| 100 |
+
if not os.path.exists(ckpt_path):
|
| 101 |
+
cache_path = os.path.join(cache_dir, "weights.tmp")
|
| 102 |
+
download_url_to_file(
|
| 103 |
+
"https://huggingface.co/SmilingWolf/wd-v1-4-swinv2-tagger-v2/resolve/main/model.safetensors",
|
| 104 |
+
dst = cache_path
|
| 105 |
+
)
|
| 106 |
+
os.rename(cache_path, ckpt_path)
|
| 107 |
+
|
| 108 |
+
if load_tag:
|
| 109 |
+
csv_path = hf_hub_download(
|
| 110 |
+
"SmilingWolf/wd-v1-4-swinv2-tagger-v2",
|
| 111 |
+
"selected_tags.csv",
|
| 112 |
+
cache_dir = cache_dir
|
| 113 |
+
# use_auth_token=HF_TOKEN,
|
| 114 |
+
)
|
| 115 |
+
tags_df = pd.read_csv(csv_path)
|
| 116 |
+
sep_tags = load_labels(tags_df)
|
| 117 |
+
|
| 118 |
+
self.tag_names = sep_tags[0]
|
| 119 |
+
self.rating_indexes = sep_tags[1]
|
| 120 |
+
self.general_indexes = sep_tags[2]
|
| 121 |
+
self.character_indexes = sep_tags[3]
|
| 122 |
+
|
| 123 |
+
self.color_tags = color_tag_index
|
| 124 |
+
self.expr_tags = geometry_tag_index
|
| 125 |
+
self.model.load_state_dict(load_file(ckpt_path))
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def convert_labels(self, pred, general_thresh=0.25, character_thresh=0.85):
|
| 129 |
+
assert self.load_tag
|
| 130 |
+
labels = list(zip(self.tag_names, pred[0].astype(float)))
|
| 131 |
+
|
| 132 |
+
# First 4 labels are actually ratings: pick one with argmax
|
| 133 |
+
# ratings_names = [labels[i] for i in self.rating_indexes]
|
| 134 |
+
# rating = dict(ratings_names)
|
| 135 |
+
|
| 136 |
+
# Then we have general tags: pick any where prediction confidence > threshold
|
| 137 |
+
general_names = [labels[i] for i in self.general_indexes]
|
| 138 |
+
|
| 139 |
+
general_res = [(x[0], np.round(x[1], decimals=4)) for x in general_names if x[1] > general_thresh]
|
| 140 |
+
general_res = dict(general_res)
|
| 141 |
+
|
| 142 |
+
# Everything else is characters: pick any where prediction confidence > threshold
|
| 143 |
+
character_names = [labels[i] for i in self.character_indexes]
|
| 144 |
+
|
| 145 |
+
character_res = [x for x in character_names if x[1] > character_thresh]
|
| 146 |
+
character_res = dict(character_res)
|
| 147 |
+
|
| 148 |
+
sorted_general_strings = sorted(
|
| 149 |
+
general_res.items(),
|
| 150 |
+
key=lambda x: x[1],
|
| 151 |
+
reverse=True,
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
sorted_general_res = sorted(
|
| 155 |
+
general_res.items(),
|
| 156 |
+
key=lambda x: x[1],
|
| 157 |
+
reverse=True,
|
| 158 |
+
)
|
| 159 |
+
sorted_general_strings = [x[0] for x in sorted_general_strings]
|
| 160 |
+
sorted_general_strings = ", ".join(sorted_general_strings).replace("(", "\\(").replace(")", "\\)")
|
| 161 |
+
|
| 162 |
+
# return sorted_general_strings, rating, character_res, general_res
|
| 163 |
+
return sorted_general_strings + ", ".join([x[0] for x in character_res.items()]), sorted_general_res
|
| 164 |
+
|
| 165 |
+
def get_transformer_length(self):
|
| 166 |
+
length = 0
|
| 167 |
+
for stage in self.model.layers:
|
| 168 |
+
length += len(stage.blocks)
|
| 169 |
+
self.transformer_length = length
|
| 170 |
+
|
| 171 |
+
def transformer_forward(self, x):
|
| 172 |
+
idx = 0
|
| 173 |
+
x = self.model.patch_embed(x)
|
| 174 |
+
for stage in self.model.layers:
|
| 175 |
+
x = stage.downsample(x)
|
| 176 |
+
for blk in stage.blocks:
|
| 177 |
+
if idx == self.transformer_length - self.layer_idx:
|
| 178 |
+
return x
|
| 179 |
+
if not torch.jit.is_scripting():
|
| 180 |
+
x = checkpoint(blk, x, use_reentrant=False)
|
| 181 |
+
else:
|
| 182 |
+
x = blk(x)
|
| 183 |
+
idx += 1
|
| 184 |
+
return x
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def forward(self, x, return_logits=False, pooled=True, **kwargs):
|
| 188 |
+
# x: [b, h, w, 3]
|
| 189 |
+
if self.direct_forward:
|
| 190 |
+
x = self.model(x)
|
| 191 |
+
else:
|
| 192 |
+
x = self.transformer_forward(x)
|
| 193 |
+
x = self.model.norm(x)
|
| 194 |
+
|
| 195 |
+
# x: [b, 14, 14, 1024]
|
| 196 |
+
if return_logits:
|
| 197 |
+
if pooled:
|
| 198 |
+
logits = self.model.forward_head(x).unsqueeze(1)
|
| 199 |
+
# x: [b, 1, 1024]
|
| 200 |
+
|
| 201 |
+
else:
|
| 202 |
+
logits = self.model.head.fc(x)
|
| 203 |
+
# x = F.sigmoid(x)
|
| 204 |
+
logits = rearrange(logits, "b h w c -> b (h w) c").contiguous()
|
| 205 |
+
# x: [b, 196, 9083]
|
| 206 |
+
|
| 207 |
+
# Need a threshold to cut off unnecessary classes.
|
| 208 |
+
if exists(self.logit_threshold) and isinstance(self.logit_threshold, float):
|
| 209 |
+
logits = torch.where(
|
| 210 |
+
logits > self.logit_threshold,
|
| 211 |
+
logits,
|
| 212 |
+
torch.ones_like(logits) * self.negative_logit
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
else:
|
| 216 |
+
logits = None
|
| 217 |
+
|
| 218 |
+
if pooled:
|
| 219 |
+
x = x.mean(dim=[1, 2]).unsqueeze(1)
|
| 220 |
+
else:
|
| 221 |
+
x = rearrange(x, "b h w c -> b (h w) c").contiguous()
|
| 222 |
+
return [x, logits]
|
| 223 |
+
|
| 224 |
+
def preprocess(self, x: torch.Tensor):
    """Resize to the model input resolution and convert RGB -> BGR.

    Assumes ``x`` is a (b, 3, h, w) channel-first batch — TODO confirm
    against callers.
    """
    x = F.interpolate(
        x,
        (self.image_size, self.image_size),
        mode = "bicubic",
        align_corners = True,
        antialias = self.antialias
    )
    # convert RGB to BGR (channel flip)
    x = x[:, [2, 1, 0]]
    return x
|
| 235 |
+
|
| 236 |
+
@torch.no_grad()
def encode(self, img: torch.Tensor, return_logits=False, pooled=True, **kwargs):
    """Gradient-free convenience wrapper: preprocess then forward.

    Input image must be in RGB format (preprocess flips to BGR).
    """
    return self(self.preprocess(img), return_logits, pooled)
|
| 240 |
+
|
| 241 |
+
@torch.no_grad()
def predict_labels(self, img: torch.Tensor, *args, **kwargs):
    """Tag a single image: sigmoid over pooled logits, then convert to labels.

    Extra args/kwargs are forwarded to ``self.convert_labels`` —
    presumably thresholds/category filters; confirm against its definition.
    """
    # Only single-image (batch size 1) 4D inputs are supported here.
    assert len(img.shape) == 4 and img.shape[0] == 1
    # Index [1] selects the logits half of forward's [features, logits] pair.
    logits = self(self.preprocess(img), return_logits=True, pooled=True)[1]
    logits = F.sigmoid(logits).detach().cpu().numpy()
    return self.convert_labels(logits, *args, **kwargs)
|
| 247 |
+
|
| 248 |
+
def geometry_update(self, emb, geometry_emb, scale_factor=1):
    """Swap the geometry-related channels of ``emb`` for those of
    ``geometry_emb``.

    Args:
        emb: WD embedding from reference image
        geometry_emb: WD embedding from sketch image
        scale_factor: multiplier applied to the injected geometry channels
            (default 1, i.e. plain replacement).

    Returns:
        ``emb`` with the channels listed in ``self.expr_tags`` replaced by
        the (scaled) corresponding channels of ``geometry_emb``; every other
        channel is untouched.
    """
    geometry_mask = torch.zeros_like(emb)
    geometry_mask[:, :, self.expr_tags] = 1  # Only geometry channels
    emb = emb * (1 - geometry_mask) + geometry_emb * geometry_mask * scale_factor
    return emb
|
| 260 |
+
|
| 261 |
+
@property
def dtype(self):
    # Report the classifier head's weight dtype as the module-wide dtype.
    return self.model.head.fc.weight.dtype
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
class OpenCLIP(nn.Module):
    """Bundles an OpenCLIP image embedder and text embedder in one module,
    delegating most calls to the visual tower."""

    def __init__(self, vision_config=None, text_config=None, **kwargs):
        super().__init__()
        # Shared kwargs are merged into both tower configs.
        # NOTE(review): ``dict.update`` mutates the caller's config dicts
        # in place, and when a config is None both towers share ``kwargs``.
        if exists(vision_config):
            vision_config.update(kwargs)
        else:
            vision_config = kwargs

        if exists(text_config):
            text_config.update(kwargs)
        else:
            text_config = kwargs

        self.visual = FrozenOpenCLIPImageEmbedder(**vision_config)
        self.transformer = FrozenOpenCLIPEmbedder(**text_config)

    def preprocess(self, x):
        # Delegate image preprocessing to the visual tower.
        return self.visual.preprocess(x)

    @property
    def scale_factor(self):
        return self.visual.scale_factor

    def update_scale_factor(self, scale_factor):
        self.visual.update_scale_factor(scale_factor)

    def encode(self, *args, **kwargs):
        # Image encoding only; see encode_text for the text tower.
        return self.visual.encode(*args, **kwargs)

    @torch.no_grad()
    def encode_text(self, text, normalize=True):
        return self.transformer(text, normalize)

    def calculate_scale(self, v: torch.Tensor, t: torch.Tensor):
        """
        Calculate the projection of v along the direction of t
        params:
            v: visual tokens from clip image encoder, shape: (b, n, c)
            t: text features from clip text encoder (argmax -1), shape: (b, 1, c)
        """
        # NOTE(review): this is the raw dot product v @ t^T; it equals the
        # projection length only when t is unit-norm — confirm at call sites.
        return v @ t.mT
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
class HFCLIPVisionModel(nn.Module):
    """CLIP vision tower backed by huggingface ``transformers``.

    Accepts images in [-1, 1], renormalises them to CLIP channel statistics,
    and returns token features passed through the post-layernorm and visual
    projection.
    """
    # TODO: open_clip_torch is incompatible with deepspeed ZeRO3, change to huggingface implementation in the future

    def __init__(self, arch="ViT-bigG-14", image_size=224, scale_factor=1.):
        super().__init__()
        self.model = CLIPVisionModelWithProjection.from_pretrained(
            hf_versions[arch],
            cache_dir = cache_dir
        )
        self.image_size = image_size
        # Rescales the working resolution relative to image_size (see preprocess).
        self.scale_factor = scale_factor
        # CLIP normalization statistics; non-persistent to keep checkpoints clean.
        self.register_buffer(
            'mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]).view(1, -1, 1, 1), persistent=False
        )
        self.register_buffer(
            'std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]).view(1, -1, 1, 1), persistent=False
        )
        self.antialias = True
        # Frozen feature extractor: no grads, eval mode.
        self.requires_grad_(False).eval()

    def preprocess(self, x):
        """Resize to ``image_size * scale_factor`` and map [-1, 1] inputs to
        CLIP's normalized range."""
        ns = int(self.image_size * self.scale_factor)
        x = F.interpolate(x, (ns, ns), mode="bicubic", align_corners=True, antialias=self.antialias)
        # [-1, 1] -> [0, 1]
        x = (x + 1.0) / 2.0

        # renormalize according to clip statistics
        x = (x - self.mean) / self.std
        return x

    def forward(self, x, output_type):
        """Run the vision tower.

        ``output_type`` selects "cls" (first token only), "local" (all but
        the first token), or anything else for the full token sequence.
        Every variant goes through post_layernorm and the visual projection.
        """
        outputs = self.model(x).last_hidden_state
        if output_type == "cls":
            outputs = outputs[:, :1]
        elif output_type == "local":
            outputs = outputs[:, 1:]
        outputs = self.model.vision_model.post_layernorm(outputs)
        outputs = self.model.visual_projection(outputs)
        return outputs

    @torch.no_grad()
    def encode(self, img, output_type="full", preprocess=True, warp_p=0., **kwargs):
        """Encode a batch; with probability ``warp_p`` (per sample) the image
        is replaced by its TPS-warped version (augmentation)."""
        img = self.preprocess(img) if preprocess else img

        if warp_p > 0.:
            rand = append_dims(torch.rand(img.shape[0], device=img.device, dtype=img.dtype), img.ndim)
            # BUGFIX: previously the condition was wrapped as
            # ``torch.Tensor(rand > warp_p)``, which copies the bool mask into
            # a float32 CPU tensor — torch.where requires a boolean condition
            # and the copy also breaks device placement. Pass the bool mask
            # directly.
            img = torch.where(rand > warp_p, img, tps_warp(img))
        return self(img, output_type)
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
class FrozenT5Embedder(nn.Module):
    """Uses the T5 transformer encoder for text (frozen by default)."""

    def __init__(
        self, version="google/t5-v1_1-xxl", device="cuda", max_length=77, freeze=True
    ):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
        super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(version, cache_dir=cache_dir)
        self.transformer = T5EncoderModel.from_pretrained(version, cache_dir=cache_dir)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()

    def freeze(self):
        """Put the encoder in eval mode and disable grads on all parameters."""
        self.transformer = self.transformer.eval()

        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        """Tokenize (pad/truncate to ``max_length``) and return the encoder's
        last hidden states, shape (b, max_length, hidden).

        The encoder runs with CUDA autocast disabled for numerical stability.
        """
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        tokens = batch_encoding["input_ids"].to(self.device)
        with torch.autocast("cuda", enabled=False):
            outputs = self.transformer(input_ids=tokens)
        z = outputs.last_hidden_state
        return z

    @torch.no_grad()
    def encode(self, text):
        # Gradient-free convenience wrapper around forward.
        return self(text)
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
class HFCLIPTextEmbedder(nn.Module):
    """Huggingface CLIP text encoder returning per-token last hidden states."""

    def __init__(self, arch, freeze=True, device="cuda", max_length=77):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(
            hf_versions[arch],
            cache_dir = cache_dir
        )
        self.model = CLIPTextModel.from_pretrained(
            hf_versions[arch],
            cache_dir = cache_dir
        )
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()

    def freeze(self):
        """Put the model in eval mode and disable grads on all parameters."""
        self.model = self.model.eval()

        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        """Encode raw strings or pre-tokenized ids.

        String input is padded/truncated to ``max_length``; a long tensor is
        treated as ready-made token ids and used as-is.
        """
        if isinstance(text, torch.Tensor) and text.dtype == torch.long:
            # Input is already tokenized
            tokens = text
        else:
            # Need to tokenize text input
            batch_encoding = self.tokenizer(
                text,
                truncation=True,
                max_length=self.max_length,
                padding="max_length",
                return_tensors="pt",
            )
            tokens = batch_encoding["input_ids"].to(self.device)

        outputs = self.model(input_ids=tokens)
        z = outputs.last_hidden_state
        return z

    @torch.no_grad()
    def encode(self, text, normalize=False):
        """Gradient-free forward; optionally L2-normalize the hidden states."""
        outputs = self(text)
        if normalize:
            outputs = outputs / outputs.norm(dim=-1, keepdim=True)
        return outputs
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
class ScalarEmbedder(nn.Module):
    """Embed a scalar batch: sinusoidal features followed by a small MLP whose
    last layer is zero-initialised (no contribution at init)."""

    def __init__(self, embed_dim, out_dim):
        super().__init__()
        # Submodule names are kept stable for checkpoint compatibility.
        self.timestep = Timestep(embed_dim)
        self.embed_layer = nn.Sequential(
            nn.Linear(embed_dim, out_dim),
            nn.SiLU(),
            zero_module(nn.Linear(out_dim, out_dim)),
        )

    def forward(self, x, dtype=torch.float32):
        """Return a (b, 1, out_dim) embedding of the scalar batch ``x``."""
        sinusoid = self.timestep(x)
        sinusoid = rearrange(sinusoid, "b d -> b 1 d")
        return self.embed_layer(sinusoid.to(dtype))
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
class TimestepEmbedding(nn.Module):
    """Thin module wrapper mapping a scalar batch to its sinusoidal embedding."""

    def __init__(self, embed_dim):
        super().__init__()
        self.timestep = Timestep(embed_dim)

    def forward(self, x):
        return self.timestep(x)
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
if __name__ == '__main__':
    # Smoke test: encode one local image and print the token-feature shape.
    import PIL.Image as Image

    encoder = FrozenOpenCLIPImageEmbedder(arch="DFN-ViT-H")
    image = Image.open("../../miniset/origin/70717450.jpg").convert("RGB")
    # Map [0, 1] pixels to [-1, 1] before encoding.
    image = (torchvision.transforms.ToTensor()(image) - 0.5) * 2
    image = image.unsqueeze(0)
    print(image.shape)
    feat = encoder.encode(image, "local")
    print(feat.shape)
|
refnet/modules/encoder.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
from refnet.util import checkpoint_wrapper
|
| 6 |
+
from refnet.modules.unet import TimestepEmbedSequential
|
| 7 |
+
from refnet.modules.layers import Upsample, zero_module, RMSNorm, FeedForward
|
| 8 |
+
from refnet.modules.attention import MemoryEfficientAttention, MultiScaleCausalAttention
|
| 9 |
+
from einops import rearrange
|
| 10 |
+
from functools import partial
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def make_zero_conv(in_channels, out_channels=None):
    """Return a 1x1 conv whose weights are zeroed (``out_channels`` defaults
    to ``in_channels``)."""
    return zero_module(nn.Conv2d(in_channels, out_channels or in_channels, 1, padding=0))
|
| 17 |
+
|
| 18 |
+
def activate_zero_conv(in_channels, out_channels=None):
    """Return SiLU followed by a zero-initialised 1x1 conv, wrapped in a
    TimestepEmbedSequential so it can be called with timestep embeddings."""
    out_ch = out_channels or in_channels
    conv = zero_module(nn.Conv2d(in_channels, out_ch, 1, padding=0))
    return TimestepEmbedSequential(nn.SiLU(), conv)
|
| 24 |
+
|
| 25 |
+
def sequential_downsample(in_channels, out_channels, sequential_cls=nn.Sequential):
    """Conv stem that widens channels and downsamples 8x (three stride-2
    convs), ending in a zero-initialised output conv.

    The layer order is identical to the original literal Sequential, so
    checkpoint state-dict keys are unchanged.
    """
    # (c_in, c_out, stride) for each conv+SiLU pair.
    specs = [
        (in_channels, 16, 1),
        (16, 16, 1),
        (16, 32, 2),
        (32, 32, 1),
        (32, 96, 2),
        (96, 96, 1),
        (96, 256, 2),
    ]
    layers = []
    for c_in, c_out, stride in specs:
        layers.append(nn.Conv2d(c_in, c_out, 3, padding=1, stride=stride))
        layers.append(nn.SiLU())
    layers.append(zero_module(nn.Conv2d(256, out_channels, 3, padding=1)))
    return sequential_cls(*layers)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class SimpleEncoder(nn.Module):
    """Minimal hint encoder: just the 8x downsampling conv stem."""

    def __init__(self, c_channels, model_channels):
        super().__init__()
        self.model = sequential_downsample(c_channels, model_channels)

    def forward(self, x, *args, **kwargs):
        # Extra args are accepted and ignored so this stays call-compatible
        # with encoders that take additional conditioning.
        return self.model(x)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class MultiEncoder(nn.Module):
    """Hint encoder emitting one zero-conv'd feature map per scale.

    The stem (``sequential_downsample``) reduces resolution 8x; each output
    block halves resolution again (except the last, which keeps it) while
    widening channels according to ``ch_mults``.
    """

    def __init__(self, in_ch, model_channels, ch_mults, checkpoint=True, time_embed=False):
        super().__init__()
        # TimestepEmbedSequential variants accept a timestep embedding argument.
        sequential_cls = TimestepEmbedSequential if time_embed else nn.Sequential
        output_chs = [model_channels * mult for mult in ch_mults]
        # NOTE(review): the stem outputs ``model_channels`` channels while
        # zero_layer expects ``output_chs[0]`` — these only match when
        # ch_mults[0] == 1; confirm configs always satisfy that.
        self.model = sequential_downsample(in_ch, model_channels, sequential_cls)
        self.zero_layer = make_zero_conv(output_chs[0])
        self.output_blocks = nn.ModuleList()
        self.zero_blocks = nn.ModuleList()

        block_num = len(ch_mults)
        prev_ch = output_chs[0]
        for i in range(block_num):
            self.output_blocks.append(sequential_cls(
                nn.SiLU(),
                nn.Conv2d(prev_ch, output_chs[i], 3, padding=1, stride=2 if i != block_num-1 else 1),
                nn.SiLU(),
                nn.Conv2d(output_chs[i], output_chs[i], 3, padding=1)
            ))
            self.zero_blocks.append(
                TimestepEmbedSequential(make_zero_conv(output_chs[i])) if time_embed
                else make_zero_conv(output_chs[i])
            )
            prev_ch = output_chs[i]

        # NOTE(review): stored but not referenced in this forward.
        self.checkpoint = checkpoint

    def forward(self, x):
        """Return hints per scale: stem output first, then each block output,
        each passed through its zero conv."""
        x = self.model(x)
        hints = [self.zero_layer(x)]
        for layer, zero_layer in zip(self.output_blocks, self.zero_blocks):
            x = layer(x)
            hints.append(zero_layer(x))
        return hints
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class MultiScaleAttentionEncoder(nn.Module):
    """Hint encoder that projects features from every scale (plus a global
    average-pooled token) into a shared width, mixes them with one multi-scale
    causal attention pass, then projects each scale back out through
    zero-initialised 1x1 convs.
    """

    def __init__(
        self,
        in_ch,
        model_channels,
        ch_mults,
        dim_head = 128,
        transformer_layers = 2,
        checkpoint = True
    ):
        super().__init__()
        conv_proj = partial(nn.Conv2d, kernel_size=1, padding=0)
        output_chs = [model_channels * mult for mult in ch_mults]
        block_num = len(ch_mults)
        # Shared attention width = widest scale's channel count.
        attn_ch = output_chs[-1]

        self.model = sequential_downsample(in_ch, output_chs[0])
        self.proj_ins = nn.ModuleList([conv_proj(output_chs[0], attn_ch)])
        self.proj_outs = nn.ModuleList([zero_module(conv_proj(attn_ch, output_chs[0]))])

        prev_ch = output_chs[0]
        self.downsample_layers = nn.ModuleList()
        for i in range(block_num):
            ch = output_chs[i]
            # Halve resolution at every level except the last.
            self.downsample_layers.append(nn.Sequential(
                nn.SiLU(),
                nn.Conv2d(prev_ch, ch, 3, padding=1, stride=2 if i != block_num - 1 else 1),
            ))
            self.proj_ins.append(conv_proj(ch, attn_ch))
            self.proj_outs.append(zero_module(conv_proj(attn_ch, ch)))
            prev_ch = ch

        # Extra projection for the global average-pooled token.
        self.proj_ins.append(conv_proj(attn_ch, attn_ch))
        self.attn_layer = MultiScaleCausalAttention(attn_ch, rope=True, qk_norm=True, dim_head=dim_head)
        # NOTE(review): ``transformer_layers`` is only used by the
        # commented-out transformer stack below.
        # self.transformer = nn.ModuleList([
        #     BasicTransformerBlock(
        #         attn_ch,
        #         rotary_positional_embedding = True,
        #         qk_norm = True,
        #         d_head = dim_head,
        #         disable_cross_attn = True,
        #         self_attn_type = "multi-scale",
        #         ff_mult = 2,
        #     )
        # ] * transformer_layers)
        self.checkpoint = checkpoint

    @checkpoint_wrapper
    def forward(self, x):
        proj_in_iter = iter(self.proj_ins)
        # proj_outs run fine->coarse; reversed to match the coarse-first
        # token order used after the flip below.
        proj_out_iter = iter(self.proj_outs[::-1])

        x = self.model(x)
        # Flatten each scale's feature map into a token sequence.
        hints = [rearrange(next(proj_in_iter)(x), "b c h w -> b (h w) c")]
        grid_sizes = [(x.shape[2], x.shape[3])]
        token_lens = [(x.shape[2] * x.shape[3])]

        for layer in self.downsample_layers:
            x = layer(x)
            h, w = x.shape[2], x.shape[3]
            grid_sizes.append((h, w))
            token_lens.append(h * w)
            hints.append(rearrange(next(proj_in_iter)(x), "b c h w -> b (h w) c"))

        # Global average-pooled token (1x1 spatial -> a single token).
        hints.append(rearrange(
            next(proj_in_iter)(x.mean(dim=[2, 3], keepdim=True)),
            "b c h w -> b (h w) c"
        ))

        # Flip so the concatenated sequence runs [global, coarsest, ..., finest].
        hints = hints[::-1]
        grid_sizes = grid_sizes[::-1]
        token_lens = token_lens[::-1]
        hints = torch.cat(hints, 1)
        # Residual multi-scale attention over the combined sequence.
        hints = self.attn_layer(hints, grid_size=grid_sizes, token_lens=token_lens) + hints
        # for layer in self.transformer:
        #     hints = layer(hints, grid_size=grid_sizes, token_lens=token_lens)

        # Split back per scale; index 0 is the global token, hence prev_idx = 1.
        prev_idx = 1
        controls = []
        for gs, token_len in zip(grid_sizes, token_lens):
            control = hints[:, prev_idx: prev_idx + token_len]
            control = rearrange(control, "b (h w) c -> b c h w", h=gs[0], w=gs[1])
            controls.append(next(proj_out_iter)(control))
            prev_idx = prev_idx + token_len
        # Restore the original scale order (highest resolution first).
        return controls[::-1]
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
class Downsampler(nn.Module):
    """Bicubic rescaling by a fixed factor, as a module."""

    def __init__(self, scale_factor):
        super().__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        return F.interpolate(x, scale_factor=self.scale_factor, mode="bicubic")
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class SpatialConditionEncoder(nn.Module):
    """Patch-embed a spatial condition map, refine it with ``n_layers``
    pre-norm attention + FFN blocks, and project to ``out_dim`` through a
    zero-initialised head (so it contributes nothing at init)."""

    def __init__(
        self,
        in_dim,
        dim,
        out_dim,
        patch_size,
        n_layers = 4,
    ):
        super().__init__()
        # Non-overlapping patchify (stride == kernel == patch_size).
        self.patch_embed = nn.Conv2d(in_dim, dim, kernel_size=patch_size, stride=patch_size)
        self.conv = nn.Sequential(nn.SiLU(), nn.Conv2d(dim, dim, kernel_size=3, padding=1))

        # Each entry: [attn-norm, attention, ffn-norm, feed-forward].
        self.transformer = nn.ModuleList(
            nn.ModuleList([
                RMSNorm(dim),
                MemoryEfficientAttention(dim, rope=True),
                RMSNorm(dim),
                FeedForward(dim, mult=2)
            ]) for _ in range(n_layers)
        )
        self.out = nn.Sequential(
            nn.SiLU(),
            zero_module(nn.Conv2d(dim, out_dim, kernel_size=1, padding=0))
        )

    def forward(self, x):
        x = self.patch_embed(x)
        x = self.conv(x)

        b, c, h, w = x.shape
        x = rearrange(x, "b c h w -> b (h w) c")
        for norm, layer, norm2, ff in self.transformer:
            # Pre-norm residual blocks; grid_size is passed for position
            # encoding (attention was built with rope=True).
            x = layer(norm(x), grid_size=(h, w)) + x
            x = ff(norm2(x)) + x
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)

        return self.out(x)
|
refnet/modules/layers.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
from refnet.util import default
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class RMSNorm(nn.Module):
    """Root-mean-square layer norm: no mean-centering, learnable per-channel gain."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # Divide by the RMS over the last dim; eps guards zero inputs.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Normalise in fp32 for stability, then cast back to the input dtype.
        normed = self._norm(x.float()).type_as(x)
        return normed * self.weight
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def init_(tensor):
    """Uniform in-place init with bound 1/sqrt(last-dim size); returns the tensor."""
    bound = 1 / math.sqrt(tensor.shape[-1])
    return tensor.uniform_(-bound, bound)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# feedforward
|
| 34 |
+
# feedforward
class GEGLU(nn.Module):
    """Gated GELU unit: project to double width, multiply one half by the
    GELU of the other."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        value, gate = self.proj(x).chunk(2, dim=-1)
        return value * F.gelu(gate)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class FeedForward(nn.Module):
    """Transformer MLP: expand by ``mult``, optionally GEGLU-gated, project
    back to ``dim_out`` (defaults to ``dim``)."""

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        hidden = int(dim * mult)
        if dim_out is None:
            dim_out = dim
        if glu:
            project_in = GEGLU(dim, hidden)
        else:
            project_in = nn.Sequential(nn.Linear(dim, hidden), nn.GELU())

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(hidden, dim_out),
        )

    def forward(self, x):
        return self.net(x)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def zero_module(module):
    """Zero every parameter of ``module`` in place and return it."""
    for param in module.parameters():
        param.detach().zero_()
    return module
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def Normalize(in_channels):
    """GroupNorm with 32 groups and the eps/affine settings used throughout."""
    return nn.GroupNorm(32, in_channels, eps=1e-6, affine=True)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class Upsample(nn.Module):
    """Nearest-neighbour 2x upsampling with an optional 3x3 convolution.

    :param channels: channels expected in the input.
    :param use_conv: whether to apply a conv after upsampling.
    :param out_channels: conv output channels; defaults to ``channels``.
    :param padding: padding for the conv.
    """

    def __init__(self, channels, use_conv, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        if use_conv:
            self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=padding)

    def forward(self, x):
        assert x.shape[1] == self.channels
        upsampled = F.interpolate(x, scale_factor=2, mode="nearest")
        if not self.use_conv:
            return upsampled
        return self.conv(upsampled)
|
refnet/modules/lora.py
ADDED
|
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
from typing import Union, Dict, List
|
| 6 |
+
from einops import rearrange
|
| 7 |
+
from refnet.util import exists, default
|
| 8 |
+
from refnet.modules.transformer import BasicTransformerBlock, SelfInjectedTransformerBlock
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def get_module_safe(self, module_path: str):
    """Resolve a dotted attribute path (e.g. ``"to_q.lora_down"``) on ``self``.

    Raises AttributeError naming the full path when any segment is missing.
    """
    node = self
    try:
        for attr in module_path.split('.'):
            node = getattr(node, attr)
    except AttributeError:
        raise AttributeError(f"Cannot find modules {module_path}")
    return node
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def switch_lora(self, v, label=None):
    """Set LoRA active-state ``v`` (optionally for ``label``) on the q/k/v projections."""
    for proj in (self.to_q, self.to_k, self.to_v):
        proj.set_lora_active(v, label)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def lora_forward(self, x, context, mask, scale=1., scale_factor=None):
    """Masked cross-attention: runs the foreground and background halves of
    ``context`` separately, toggling the shared "foreground" LoRA between the
    two passes, then composites the results with a soft character mask.

    Args:
        x: Query input (token sequence).
        context: Key/value input; first half along dim 1 is the character
            (foreground) context, second half the background context.
        mask: Character mask, presumably (b, 1, h, w) — confirm at callers.
        scale: Attention scale.
        scale_factor: Ratio used to resize the mask to the current latent size.
    """
    def qkv_forward(x, context):
        q = self.to_q(x)
        k = self.to_k(context)
        v = self.to_v(context)
        return q, k, v

    assert exists(scale_factor), "Scale factor must be assigned before masked attention"

    # Resize the mask to the latent grid and flatten to token order.
    mask = rearrange(
        F.interpolate(mask, scale_factor=scale_factor, mode="bicubic"),
        "b c h w -> b (h w) c"
    ).contiguous()

    c1, c2 = context.chunk(2, dim=1)

    # Background region cross-attention (foreground LoRA disabled)
    if self.use_lora:
        self.switch_lora(False, "foreground")
    q2, k2, v2 = qkv_forward(x, c2)
    bg_out = self.attn_forward(q2, k2, v2, scale) * self.bg_scale

    # Character region cross-attention (foreground LoRA enabled)
    if self.use_lora:
        self.switch_lora(True, "foreground")
    q1, k1, v1 = qkv_forward(x, c1)
    fg_out = self.attn_forward(q1, k1, v1, scale) * self.fg_scale

    # Blend foreground toward background by merge_scale, then composite
    # with the soft mask (hard-threshold variant kept below, disabled).
    fg_out = fg_out * (1 - self.merge_scale) + bg_out * self.merge_scale
    return fg_out * mask + bg_out * (1 - mask)
    # return torch.where(mask > self.mask_threshold, fg_out, bg_out)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def dual_lora_forward(self, x, context, mask, scale=1., scale_factor=None):
    """
    This function hacks cross-attention layers.

    Like ``lora_forward`` but with two LoRA sets: the "background" LoRA is
    enabled for the background pass and the "foreground" LoRA for the
    character pass. The outputs are composited with a hard mask threshold
    (the soft-mask variant is kept below, disabled).

    Args:
        x: Query input
        context: Key and value input; first half along dim 1 is the character
            (foreground) context, second half the background context.
        mask: Character mask, presumably (b, 1, h, w) — confirm at callers.
        scale: Attention scale
        scale_factor: Current latent size factor used to resize the mask.
    """
    def qkv_forward(x, context):
        q = self.to_q(x)
        k = self.to_k(context)
        v = self.to_v(context)
        return q, k, v

    assert exists(scale_factor), "Scale factor must be assigned before masked attention"

    # Resize the mask to the latent grid and flatten to token order.
    mask = rearrange(
        F.interpolate(mask, scale_factor=scale_factor, mode="bicubic"),
        "b c h w -> b (h w) c"
    ).contiguous()

    c1, c2 = context.chunk(2, dim=1)

    # Background region cross-attention (background LoRA on, foreground off)
    if self.use_lora:
        self.switch_lora(True, "background")
        self.switch_lora(False, "foreground")
    q2, k2, v2 = qkv_forward(x, c2)
    bg_out = self.attn_forward(q2, k2, v2, scale) * self.bg_scale

    # Foreground region cross-attention (foreground LoRA on, background off)
    if self.use_lora:
        self.switch_lora(False, "background")
        self.switch_lora(True, "foreground")
    q1, k1, v1 = qkv_forward(x, c1)
    fg_out = self.attn_forward(q1, k1, v1, scale) * self.fg_scale

    # Blend foreground toward background, then hard-select per token.
    fg_out = fg_out * (1 - self.merge_scale) + bg_out * self.merge_scale
    # return fg_out * mask + bg_out * (1 - mask)
    return torch.where(mask > self.mask_threshold, fg_out, bg_out)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class MultiLoraInjectedLinear(nn.Linear):
    """
    A linear layer that can hold multiple LoRA adapters and merge them.

    The base weight/bias are frozen at construction; each adapter is a
    (down, up) pair of trainable Linears keyed by a string label.  Adapters
    can be toggled, rescaled, merged into the base weights, and the original
    weights recovered afterwards.
    """
    def __init__(
            self,
            in_features,
            out_features,
            bias = False,
    ):
        super().__init__(in_features, out_features, bias)
        # Bookkeeping keyed by adapter label.
        self.lora_adapters: Dict[str, Dict[str, nn.Module]] = {}  # {label: {up/down: layer}}
        self.lora_scales: Dict[str, float] = {}
        self.active_loras: Dict[str, bool] = {}
        # Snapshots of the pre-merge weights; populated lazily by
        # merge_lora_weights() so recover_original_weight() can undo a merge.
        self.original_weight = None
        self.original_bias = None

        # Freeze original weights — only the adapters are trained.
        self.weight.requires_grad_(False)
        if exists(self.bias):
            self.bias.requires_grad_(False)

    def add_lora_adapter(self, label: str, r: int, scale: float = 1.0, dropout_p: float = 0.0):
        """Add a new LoRA adapter with the given label.

        `r` may be a float, in which case it is interpreted as a fraction of
        `out_features`.  Down is init'd with std 1/r and up with zeros, so a
        freshly added adapter is a no-op until trained.
        """
        if isinstance(r, float):
            r = int(r * self.out_features)

        lora_down = nn.Linear(self.in_features, r, bias=self.bias is not None)
        lora_up = nn.Linear(r, self.out_features, bias=self.bias is not None)
        dropout = nn.Dropout(dropout_p)

        # Initialize weights (up is zeroed so output starts unchanged)
        nn.init.normal_(lora_down.weight, std=1 / r)
        nn.init.zeros_(lora_up.weight)

        self.lora_adapters[label] = {
            'down': lora_down,
            'up': lora_up,
            'dropout': dropout,
        }
        self.lora_scales[label] = scale
        self.active_loras[label] = True

        # Register as submodules so they appear in parameters()/state_dict()
        self.add_module(f'lora_down_{label}', lora_down)
        self.add_module(f'lora_up_{label}', lora_up)
        self.add_module(f'lora_dropout_{label}', dropout)

    def get_trainable_layers(self, label: str = None):
        """Get trainable layers for specific LoRA or all LoRAs."""
        layers = []
        if exists(label):
            if label in self.lora_adapters:
                adapter = self.lora_adapters[label]
                layers.extend([adapter['down'], adapter['up']])
        else:
            for adapter in self.lora_adapters.values():
                layers.extend([adapter['down'], adapter['up']])
        return layers

    def set_lora_active(self, active: bool, label: str):
        """Activate or deactivate a specific LoRA adapter (unknown labels are ignored)."""
        if label in self.active_loras:
            self.active_loras[label] = active

    def set_lora_scale(self, scale: float, label: str):
        """Set the scale for a specific LoRA adapter (unknown labels are ignored)."""
        if label in self.lora_scales:
            self.lora_scales[label] = scale

    def merge_lora_weights(self, labels: List[str] = None):
        """Merge specified LoRA adapters into the base weights.

        Only adapters that are currently active are merged.  Merged adapters
        are deactivated afterwards so the forward pass does not apply them
        twice.  The pre-merge weights are kept so recover_original_weight()
        can undo the merge.
        """
        if labels is None:
            labels = list(self.lora_adapters.keys())

        # Store original weights if not already stored
        if self.original_weight is None:
            self.original_weight = self.weight.clone()
            if exists(self.bias):
                self.original_bias = self.bias.clone()

        merged_weight = self.original_weight.clone()
        merged_bias = self.original_bias.clone() if exists(self.original_bias) else None

        for label in labels:
            if label in self.lora_adapters and self.active_loras.get(label, False):
                lora_up, lora_down = self.lora_adapters[label]['up'], self.lora_adapters[label]['down']
                scale = self.lora_scales[label]

                # W' = W + scale * (up @ down)
                lora_weight = lora_up.weight @ lora_down.weight
                merged_weight += scale * lora_weight

                # b' = b + scale * (up_bias + up_weight @ down_bias)
                if exists(merged_bias) and exists(lora_up.bias):
                    lora_bias = lora_up.bias + lora_up.weight @ lora_down.bias
                    merged_bias += scale * lora_bias

        # Update weights
        self.weight = nn.Parameter(merged_weight, requires_grad=False)
        if exists(merged_bias):
            self.bias = nn.Parameter(merged_bias, requires_grad=False)

        # Deactivate all LoRAs after merging
        for label in labels:
            self.active_loras[label] = False

    def recover_original_weight(self):
        """Recover the original weights before any LoRA modifications.

        NOTE(review): this reactivates *all* adapters, including ones that
        were never merged or were deactivated explicitly — confirm intended.
        """
        if self.original_weight is not None:
            self.weight = nn.Parameter(self.original_weight.clone())
            if exists(self.original_bias):
                self.bias = nn.Parameter(self.original_bias.clone())

        # Reactivate all LoRAs
        for label in self.active_loras:
            self.active_loras[label] = True

    def forward(self, input):
        # Base (frozen) linear transform.
        output = super().forward(input)

        # Add contributions from active LoRAs
        for label, adapter in self.lora_adapters.items():
            if self.active_loras.get(label, False):
                lora_out = adapter['up'](adapter['dropout'](adapter['down'](input)))
                output += self.lora_scales[label] * lora_out

        return output
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
class LoraModules:
    """Manager that injects and controls MultiLoraInjectedLinear layers in a model.

    `lora_params` is a sequence of config dicts; each must contain
    "root_module" (dotted path resolved via get_module_safe) and may contain
    "label" plus any keyword accepted by inject_lora().
    """
    def __init__(self, sd, lora_params, *args, **kwargs):
        self.modules = {}  # label -> list of parent modules that were hooked
        self.multi_lora_layers: Dict[str, MultiLoraInjectedLinear] = {}  # path -> MultiLoraLayer

        for cfg in lora_params:
            root_module = get_module_safe(sd, cfg.pop("root_module"))
            label = cfg.pop("label", "lora")
            self.inject_lora(label, root_module, **cfg)

    def inject_lora(
            self,
            label,
            root_module,
            r,
            split_forward = False,
            target_keys = ("to_q", "to_k", "to_v"),
            filter_keys = None,
            target_class = None,
            scale = 1.0,
            dropout_p = 0.0,
    ):
        """Replace matching child Linears under `root_module` with
        MultiLoraInjectedLinear (or add an adapter to an existing one).

        A child matches when its dotted path ends with one of `target_keys`
        or it is an instance of one of `target_class` (string class paths),
        unless its path contains one of `filter_keys`.
        """
        def check_condition(path, child, class_list):
            # Exclusion filter takes precedence over any positive match.
            if exists(filter_keys) and any(path.find(key) > -1 for key in filter_keys):
                return False
            if exists(target_keys) and any(path.endswith(key) for key in target_keys):
                return True
            if exists(class_list) and any(
                isinstance(child, module_class) for module_class in class_list
            ):
                return True
            return False

        def retrieve_target_modules():
            from refnet.util import get_obj_from_str
            target_class_list = [get_obj_from_str(t) for t in target_class] if exists(target_class) else None

            modules = []
            # Walk (parent, child) pairs so the child can be swapped in place.
            for name, module in root_module.named_modules():
                for key, child in module._modules.items():
                    full_path = name + '.' + key if name else key
                    if check_condition(full_path, child, target_class_list):
                        modules.append((module, child, key, full_path))
            return modules

        modules: list[Union[nn.Module]] = []
        retrieved_modules = retrieve_target_modules()

        for parent, child, child_name, full_path in retrieved_modules:
            # Check if this layer already has a MultiLoraInjectedLinear
            if full_path in self.multi_lora_layers:
                # Add LoRA to existing MultiLoraInjectedLinear
                multi_lora_layer = self.multi_lora_layers[full_path]
                multi_lora_layer.add_lora_adapter(label, r, scale, dropout_p)
            else:
                # Check if the current layer is already a MultiLoraInjectedLinear
                if isinstance(child, MultiLoraInjectedLinear):
                    child.add_lora_adapter(label, r, scale, dropout_p)
                    self.multi_lora_layers[full_path] = child
                else:
                    # Replace with MultiLoraInjectedLinear and add first LoRA
                    multi_lora_layer = MultiLoraInjectedLinear(
                        in_features=child.weight.shape[1],
                        out_features=child.weight.shape[0],
                        bias=exists(child.bias),
                    )

                    multi_lora_layer.add_lora_adapter(label, r, scale, dropout_p)
                    parent._modules[child_name] = multi_lora_layer
                    self.multi_lora_layers[full_path] = multi_lora_layer

            # NOTE(review): lora_forward / dual_lora_forward / switch_lora are
            # module-level helpers defined elsewhere in this file; the bare
            # `switch_lora` below is that function, not this class's method.
            if split_forward:
                parent.masked_forward = dual_lora_forward.__get__(parent, parent.__class__)
            else:
                parent.masked_forward = lora_forward.__get__(parent, parent.__class__)

            parent.use_lora = True
            parent.switch_lora = switch_lora.__get__(parent, parent.__class__)
            modules.append(parent)

        self.modules[label] = modules
        print(f"Activated {label} lora with {len(self.multi_lora_layers)} layers")
        return self.multi_lora_layers, modules

    def get_trainable_layers(self, label = None):
        """Get all trainable layers, optionally filtered by label."""
        layers = []
        for lora_layer in self.multi_lora_layers.values():
            layers += lora_layer.get_trainable_layers(label)
        return layers

    def switch_lora(self, mode, label = None):
        # Toggle one labelled adapter (and its hooked parents), or all of them.
        if exists(label):
            for layer in self.multi_lora_layers.values():
                layer.set_lora_active(mode, label)
            for module in self.modules[label]:
                module.use_lora = mode
        else:
            for layer in self.multi_lora_layers.values():
                for lora_label in layer.lora_adapters.keys():
                    layer.set_lora_active(mode, lora_label)

            for modules in self.modules.values():
                for module in modules:
                    module.use_lora = mode

    def adjust_lora_scales(self, scale, label = None):
        # Set the scale of one labelled adapter, or of every adapter.
        if exists(label):
            for layer in self.multi_lora_layers.values():
                layer.set_lora_scale(scale, label)
        else:
            for layer in self.multi_lora_layers.values():
                for lora_label in layer.lora_adapters.keys():
                    layer.set_lora_scale(scale, lora_label)

    def merge_lora(self, labels = None):
        # Bake the given adapters (default: all) into the base weights.
        if labels is None:
            labels = list(self.modules.keys())
        elif isinstance(labels, str):
            labels = [labels]

        for layer in self.multi_lora_layers.values():
            layer.merge_lora_weights(labels)

    def recover_lora(self):
        # Undo merge_lora() on every injected layer.
        for layer in self.multi_lora_layers.values():
            layer.recover_original_weight()

    def get_lora_info(self):
        """Get information about all LoRA adapters."""
        info = {}
        for path, layer in self.multi_lora_layers.items():
            info[path] = {
                'labels': list(layer.lora_adapters.keys()),
                'active': {label: active for label, active in layer.active_loras.items()},
                'scales': layer.lora_scales.copy()
            }
        return info
|
refnet/modules/proj.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from refnet.modules.layers import zero_module
|
| 5 |
+
from refnet.modules.attention import MemoryEfficientAttention
|
| 6 |
+
from refnet.modules.transformer import BasicTransformerBlock
|
| 7 |
+
from refnet.util import checkpoint_wrapper, exists
|
| 8 |
+
from refnet.util import load_weights
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class NormalizedLinear(nn.Module):
    """Linear projection followed by LayerNorm, with optional activation
    checkpointing (driven by `self.checkpoint` via `checkpoint_wrapper`)."""

    def __init__(self, dim, output_dim, checkpoint=True):
        super().__init__()
        blocks = [nn.Linear(dim, output_dim), nn.LayerNorm(output_dim)]
        self.layers = nn.Sequential(*blocks)
        self.checkpoint = checkpoint

    @checkpoint_wrapper
    def forward(self, x):
        # Project then normalize.
        return self.layers(x)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class GlobalProjection(nn.Module):
    """Project a global embedding into a sequence of `heads[0] * heads[1]`
    tokens of width `output_dim`.

    proj1 fans the input out to heads[0] sub-vectors of size dim_head; proj2
    (zero-initialized, so the module starts as a no-op contribution) expands
    each into heads[1] tokens of size output_dim.
    """

    def __init__(self, input_dim, output_dim, heads, dim_head=128, checkpoint=True):
        super().__init__()
        self.c_dim = output_dim
        self.dim_head = dim_head
        # (tokens after proj1, tokens after proj2)
        self.head = (heads[0], heads[0] * heads[1])

        self.proj1 = nn.Linear(input_dim, dim_head * heads[0])
        self.proj2 = nn.Sequential(
            nn.SiLU(),
            zero_module(nn.Linear(dim_head, output_dim * heads[1])),
        )
        self.norm = nn.LayerNorm(output_dim)
        self.checkpoint = checkpoint

    @checkpoint_wrapper
    def forward(self, x):
        tokens_in, tokens_out = self.head
        h = self.proj1(x)
        h = h.reshape(-1, tokens_in, self.dim_head).contiguous()
        h = self.proj2(h)
        h = h.reshape(-1, tokens_out, self.c_dim).contiguous()
        return self.norm(h)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class ClusterConcat(nn.Module):
    """Attend over input tokens, keep the first `token_length`, concatenate a
    conditioning embedding channel-wise, and project to `output_dim`.

    With `fgbg=True` the batch is assumed to stack foreground and background
    halves; they are re-joined along the token axis.
    """

    def __init__(self, input_dim, c_dim, output_dim, dim_head=64, token_length=196, checkpoint=True):
        super().__init__()
        self.attn = MemoryEfficientAttention(input_dim, dim_head=dim_head)
        self.norm = nn.LayerNorm(input_dim)
        self.proj = nn.Sequential(
            nn.Linear(input_dim + c_dim, output_dim),
            nn.SiLU(),
            nn.Linear(output_dim, output_dim),
            nn.LayerNorm(output_dim)
        )
        self.token_length = token_length
        self.checkpoint = checkpoint

    @checkpoint_wrapper
    def forward(self, x, emb, fgbg=False, *args, **kwargs):
        tokens = self.attn(x)[:, :self.token_length]
        tokens = self.norm(tokens)
        fused = self.proj(torch.cat([tokens, emb], 2))

        if fgbg:
            # Split the stacked fg/bg batch halves and join them token-wise.
            halves = torch.chunk(fused, 2)
            fused = torch.cat(halves, 1)
        return fused
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class RecoveryClusterConcat(ClusterConcat):
    """ClusterConcat variant with an extra self-attention transformer block
    that is applied only to background inputs (`bg=True`)."""

    def __init__(self, input_dim, c_dim, output_dim, dim_head=64, *args, **kwargs):
        super().__init__(input_dim, c_dim, output_dim, dim_head=dim_head, *args, **kwargs)
        self.transformer = BasicTransformerBlock(
            output_dim, output_dim//dim_head, dim_head,
            disable_cross_attn=True, checkpoint=False
        )

    @checkpoint_wrapper
    def forward(self, x, emb, bg=False):
        feats = self.norm(self.attn(x)[:, :self.token_length])
        feats = self.proj(torch.cat([feats, emb], 2))
        # Background features get one extra round of self-attention.
        return self.transformer(feats) if bg else feats
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class LogitClusterConcat(ClusterConcat):
    """ClusterConcat variant that first maps the conditioning embedding
    through a frozen AdaptiveMLP (optionally loaded from a checkpoint)."""

    def __init__(self, c_dim, mlp_in_dim, mlp_ckpt_path=None, *args, **kwargs):
        super().__init__(c_dim=c_dim, *args, **kwargs)
        self.mlp = AdaptiveMLP(c_dim, mlp_in_dim)
        if exists(mlp_ckpt_path):
            self.mlp.load_state_dict(load_weights(mlp_ckpt_path), strict=True)

    @checkpoint_wrapper
    def forward(self, x, emb, bg=False):
        # The MLP is used as a fixed feature extractor: no_grad + detach keep
        # gradients from flowing into (or through) it.
        with torch.no_grad():
            emb = self.mlp(emb).detach()
        return super().forward(x, emb, bg)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class AdaptiveMLP(nn.Module):
    """MLP that fuses the outputs of every intermediate stage.

    Stage 0 is a plain Linear(in_dim -> dim); each later stage is
    SiLU -> LayerNorm -> Linear(dim -> dim).  All stage outputs are
    concatenated on the channel axis and fused back down to `dim`.
    """

    def __init__(self, dim, in_dim, layers=4, checkpoint=True):
        super().__init__()

        stages = [nn.Sequential(nn.Linear(in_dim, dim))]
        stages += [
            nn.Sequential(nn.SiLU(), nn.LayerNorm(dim), nn.Linear(dim, dim))
            for _ in range(layers - 1)
        ]
        self.mlp = nn.Sequential(*stages)
        self.fusion_layer = nn.Linear(dim * layers, dim, bias=False)
        self.norm = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    @checkpoint_wrapper
    def forward(self, x):
        # Collect every stage's output for the fusion step.
        stage_outputs = []
        h = x
        for stage in self.mlp:
            h = stage(h)
            stage_outputs.append(h)

        fused = self.fusion_layer(torch.cat(stage_outputs, dim=2))
        return self.norm(fused)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class Concat(nn.Module):
    """Trivial fusion module: concatenate two tensors on the last dimension.

    Extra constructor/forward arguments are accepted and ignored so it can be
    swapped in for richer fusion modules.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x, y, *args, **kwargs):
        return torch.cat((x, y), dim=-1)
|
refnet/modules/reference_net.py
ADDED
|
@@ -0,0 +1,430 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
from einops import rearrange
|
| 6 |
+
from typing import Union
|
| 7 |
+
from functools import partial
|
| 8 |
+
|
| 9 |
+
from refnet.modules.unet_old import (
|
| 10 |
+
timestep_embedding,
|
| 11 |
+
conv_nd,
|
| 12 |
+
TimestepEmbedSequential,
|
| 13 |
+
exists,
|
| 14 |
+
ResBlock,
|
| 15 |
+
linear,
|
| 16 |
+
Downsample,
|
| 17 |
+
zero_module,
|
| 18 |
+
SelfTransformerBlock,
|
| 19 |
+
SpatialTransformer,
|
| 20 |
+
)
|
| 21 |
+
from refnet.modules.unet import DualCondUNetXL
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def hack_inference_forward(model):
    """Rebind `model.forward` to the inference-only entry point."""
    model.forward = InferenceForward.__get__(model, type(model))
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def hack_unet_forward(unet):
    """Swap in the enhanced decoder forward, keeping the original so
    restore_unet_forward() can undo the change."""
    unet.original_forward = unet._forward
    forward_fn = enhanced_forward_xl if isinstance(unet, DualCondUNetXL) else enhanced_forward
    unet._forward = forward_fn.__get__(unet, type(unet))
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def restore_unet_forward(unet):
    """Undo hack_unet_forward(); a no-op if the forward was never replaced."""
    original = getattr(unet, "original_forward", None)
    if original is not None:
        unet._forward = original.__get__(unet, type(unet))
        del unet.original_forward
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def modulation(x, scale, shift):
    """FiLM-style affine modulation: x * (1 + scale) + shift."""
    return shift + x * (scale + 1)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def enhanced_forward(
        self,
        x: torch.Tensor,
        emb: torch.Tensor,
        hs_fg: torch.Tensor = None,
        hs_bg: torch.Tensor = None,
        mask: torch.Tensor = None,
        threshold: Union[float, torch.Tensor] = None,
        control: torch.Tensor = None,
        context: torch.Tensor = None,
        style_modulations: torch.Tensor = None,
        **additional_context
):
    """UNet inner forward with reference-feature injection.

    On the decoder path, when both `mask` and `threshold` are given, each
    skip connection is replaced per-pixel: foreground pixels (mask above
    threshold) take `map_modules[idx](h_skip, hs_fg[idx])`, background
    pixels take `warp_modules[idx](h_skip, hs_bg[idx])`; either side falls
    back to the raw skip when its features are absent.  `control` hints
    (from a hint encoder) are added at `hint_encoder_index` blocks, and
    `style_modulations` apply per-block scale/shift modulation.

    Fix vs. original: `control` defaults to None but was passed straight to
    `iter()`, which raises TypeError; it is now only iterated when present.
    The annotation `Union[float | torch.Tensor]` (a one-element Union) is
    normalized to `Union[float, torch.Tensor]`.
    """
    h = x.to(self.dtype)
    emb = emb.to(self.dtype)
    hs = []

    # Only build the hint iterator when control hints were supplied.
    control_iter = iter(control) if exists(control) else None
    for idx, module in enumerate(self.input_blocks):
        h = module(h, emb, context, mask, **additional_context)

        if exists(control_iter) and idx in self.hint_encoder_index:
            h += next(control_iter)

        hs.append(h)

    h = self.middle_block(h, emb, context, mask, **additional_context)

    for idx, module in enumerate(self.output_blocks):
        h_skip = hs.pop()

        if exists(mask) and exists(threshold):
            # inject foreground/background features, selected per pixel by
            # the mask resampled to this block's resolution
            B, C, H, W = h_skip.shape
            cm = F.interpolate(mask, (H, W), mode="bicubic")
            h = torch.cat([h, torch.where(
                cm > threshold,
                self.map_modules[idx](h_skip, hs_fg[idx]) if exists(hs_fg) else h_skip,
                self.warp_modules[idx](h_skip, hs_bg[idx]) if exists(hs_bg) else h_skip
            )], 1)

        else:
            h = torch.cat([h, h_skip], 1)

        h = module(h, emb, context, mask, **additional_context)

        if exists(style_modulations):
            # Per-block FiLM modulation conditioned on style + timestep emb.
            style_norm, emb_proj, style_proj = self.style_modules[idx]
            style_m = style_modulations[idx] + emb_proj(emb)
            style_m = style_proj(style_norm(style_m))[..., None, None]
            scale, shift = style_m.chunk(2, dim=1)

            h = modulation(h, scale, shift)

    return h
|
| 101 |
+
|
| 102 |
+
def enhanced_forward_xl(
        self,
        x: torch.Tensor,
        emb,
        z_fg: torch.Tensor = None,
        z_bg: torch.Tensor = None,
        hs_fg: torch.Tensor = None,
        hs_bg: torch.Tensor = None,
        mask: torch.Tensor = None,
        inject_mask: torch.Tensor = None,
        threshold: Union[float, torch.Tensor] = None,
        concat: torch.Tensor = None,
        control: torch.Tensor = None,
        context: torch.Tensor = None,
        style_modulations: torch.Tensor = None,
        **additional_context
):
    """SDXL variant of `enhanced_forward` for DualCondUNetXL.

    Differences from the base variant: optional channel-concat input
    (`concat` + residual `concat_conv`), fg/bg latents (`z_fg`/`z_bg`)
    added once at the first encoder block, attention-based fg/bg injection
    on the decoder skips (map/warp modules operate on flattened tokens and
    are applied residually), and extra `control` hints at
    `hint_decoder_index` blocks.

    Fix vs. original: `control` defaults to None but was passed straight to
    `iter()`, which raises TypeError; it is now only iterated when present.
    The annotation `Union[float | torch.Tensor]` is normalized to
    `Union[float, torch.Tensor]`.
    """
    h = x.to(self.dtype)
    emb = emb.to(self.dtype)
    hs = []
    # Only build the hint iterator when control hints were supplied.
    control_iter = iter(control) if exists(control) else None

    if exists(concat):
        h = torch.cat([h, concat], 1)
        h = h + self.concat_conv(h)

    for idx, module in enumerate(self.input_blocks):
        h = module(h, emb, context, mask, **additional_context)

        if exists(control_iter) and idx in self.hint_encoder_index:
            h += next(control_iter)

        # fg/bg latents are consumed once, at the first block.
        if exists(z_fg):
            h += self.conv_fg(z_fg)
            z_fg = None
        if exists(z_bg):
            h += self.conv_bg(z_bg)
            z_bg = None

        hs.append(h)

    h = self.middle_block(h, emb, context, mask, **additional_context)

    for idx, module in enumerate(self.output_blocks):
        h_skip = hs.pop()

        if exists(inject_mask) and exists(threshold):
            # inject foreground/background features, selected per pixel by
            # the inject mask resampled to this block's resolution
            B, C, H, W = h_skip.shape
            cm = F.interpolate(inject_mask, (H, W), mode="bicubic")
            h = torch.cat([h, torch.where(
                cm > threshold,

                # foreground injection: attend over flattened skip tokens
                # conditioned on fg features + projected timestep emb,
                # applied as a residual on the skip
                rearrange(
                    self.map_modules[idx][0](
                        rearrange(h_skip, "b c h w -> b (h w) c"),
                        hs_fg[idx] + self.map_modules[idx][1](emb).unsqueeze(1)
                    ), "b (h w) c -> b c h w", h=H, w=W
                ) + h_skip if exists(hs_fg) else h_skip,

                # background injection (same structure, warp modules)
                rearrange(
                    self.warp_modules[idx][0](
                        rearrange(h_skip, "b c h w -> b (h w) c"),
                        hs_bg[idx] + self.warp_modules[idx][1](emb).unsqueeze(1)
                    ), "b (h w) c -> b c h w", h=H, w=W
                ) + h_skip if exists(hs_bg) else h_skip
            )], 1)

        else:
            h = torch.cat([h, h_skip], 1)

        h = module(h, emb, context, mask, **additional_context)

        if exists(style_modulations):
            # Per-block FiLM modulation conditioned on style + timestep emb.
            style_norm, emb_proj, style_proj = self.style_modules[idx]
            style_m = style_modulations[idx] + emb_proj(emb)
            style_m = style_proj(style_norm(style_m))[..., None, None]
            scale, shift = style_m.chunk(2, dim=1)

            h = modulation(h, scale, shift)

        if exists(control_iter) and idx in self.hint_decoder_index:
            h += next(control_iter)

    return h
|
| 189 |
+
|
| 190 |
+
def InferenceForward(self, x, timesteps=None, y=None, *args, **kwargs):
    """Inference entry point bound onto a UNet by `hack_inference_forward`.

    Builds the timestep (and optional class/ADM) embedding, then delegates
    to the model's `_forward`.  `y` must be given exactly when the model is
    class-conditional.
    """
    t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
    emb = self.time_embed(t_emb).to(self.dtype)
    assert (y is not None) == (
        self.num_classes is not None
    ), "must specify y if and only if the model is class-conditional"

    if self.num_classes is not None:
        assert y.shape[0] == x.shape[0]
        emb = emb + self.label_emb(y.to(self.dtype))
        emb = emb.to(self.dtype)
    return self._forward(x, emb, *args, **kwargs)
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
class UNetEncoderXL(nn.Module):
    """Encoder-only SDXL-style UNet that produces per-block hint features.

    It mirrors the encoder half of an SDXL UNet (input blocks only; no middle
    or output blocks) and attaches a zero-initialized layer to every input
    block.  `_forward` returns, per block, either a globally pooled style
    vector (when `style_modulation` is on) or a zero-conv feature map
    flattened to tokens; the list is reversed before returning.
    """
    transformers = {
        "vanilla": SpatialTransformer,
    }

    def __init__(
            self,
            in_channels,
            model_channels,
            num_res_blocks,
            attention_resolutions,
            dropout = 0,
            channel_mult = (1, 2, 4, 8),
            conv_resample = True,
            dims = 2,
            num_classes = None,
            use_checkpoint = False,
            num_heads = -1,
            num_head_channels = -1,
            use_scale_shift_norm = False,
            resblock_updown = False,
            use_spatial_transformer = False,    # custom transformer support
            transformer_depth = 1,              # custom transformer support
            context_dim = None,                 # custom transformer support
            disable_self_attentions = None,
            disable_cross_attentions = None,
            num_attention_blocks = None,
            use_linear_in_transformer = False,
            adm_in_channels = None,
            transformer_type = "vanilla",
            style_modulation = False,
    ):
        super().__init__()
        if use_spatial_transformer:
            assert exists(
                context_dim) or disable_cross_attentions, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
            assert transformer_type in self.transformers.keys(), f'Assigned transformer is not implemented.. Choices: {self.transformers.keys()}'
            from omegaconf.listconfig import ListConfig
            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        time_embed_dim = model_channels * 4
        # NOTE(review): this `resblock` partial is never used below —
        # ResBlock is instantiated directly; kept for byte-compatibility.
        resblock = partial(
            ResBlock,
            emb_channels=time_embed_dim,
            dropout=dropout,
            dims=dims,
            use_checkpoint=use_checkpoint,
            use_scale_shift_norm=use_scale_shift_norm,
        )

        transformer = partial(
            self.transformers[transformer_type],
            context_dim=context_dim,
            use_linear=use_linear_in_transformer,
            use_checkpoint=use_checkpoint,
            disable_self_attn=disable_self_attentions,
            disable_cross_attn=disable_cross_attentions,
        )

        # Exactly one of num_heads / num_head_channels must be specified.
        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
        self.in_channels = in_channels
        self.model_channels = model_channels
        # Normalize num_res_blocks to a per-level list.
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                 "as a list/tuple (per-level) with the same length as channel_mult")
            self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(
                map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                  f"attention will still not be set.")

        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = torch.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.style_modulation = style_modulation

        if isinstance(transformer_depth, int):
            transformer_depth = len(channel_mult) * [transformer_depth]

        # (Recomputed here in the original; same value as above.)
        time_embed_dim = model_channels * 4
        zero_conv = partial(nn.Conv2d, kernel_size=1, stride=1, padding=0)

        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
            elif self.num_classes == "continuous":
                print("setting up linear c_adm embedding layer")
                self.label_emb = nn.Linear(1, time_embed_dim)
            elif self.num_classes == "sequential":
                # SDXL-style ADM conditioning vector.
                assert adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
                        linear(adm_in_channels, time_embed_dim),
                        nn.SiLU(),
                        linear(time_embed_dim, time_embed_dim),
                    )
                )
            else:
                raise ValueError()

        # Stem: one conv block; zero_layers is kept index-aligned with
        # input_blocks (one zero layer per block).
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self.zero_layers = nn.ModuleList([zero_module(
            nn.Linear(model_channels, model_channels * 2) if style_modulation else
            zero_conv(model_channels, model_channels)
        )])

        ch = model_channels
        ds = 1  # current downsampling factor
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    num_heads = ch // num_head_channels

                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                        layers.append(
                            SelfTransformerBlock(ch, num_head_channels)
                            if not use_spatial_transformer
                            else transformer(
                                ch, num_heads, num_head_channels, depth=transformer_depth[level]
                            )
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self.zero_layers.append(zero_module(
                    nn.Linear(ch, ch * 2) if style_modulation else zero_conv(ch, ch)
                ))

            if level != len(channel_mult) - 1:
                # Downsample between levels (ResBlock(down=True) or Downsample).
                out_ch = ch
                self.input_blocks.append(TimestepEmbedSequential(
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=out_ch,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                        down=True,
                    ) if resblock_updown else Downsample(
                        ch, conv_resample, dims=dims, out_channels=out_ch
                    )
                ))
                # NOTE(review): this style-modulation width differs from the
                # `ch * 2` used elsewhere — confirm the consumer expects
                # min(model_channels * 8, out_ch * 4) here.
                self.zero_layers.append(zero_module(
                    nn.Linear(out_ch, min(model_channels * 8, out_ch * 4)) if style_modulation else
                    zero_conv(out_ch, out_ch)
                ))
                ch = out_ch
                ds *= 2


    def forward(self, x, timesteps = None, y = None, *args, **kwargs):
        """Embed timesteps (+ optional class/ADM vector) and run the encoder."""
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
        emb = self.time_embed(t_emb)

        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y.to(self.dtype))

        hs = self._forward(x, emb, *args, **kwargs)
        return hs

    def _forward(self, x, emb, context = None, **additional_context):
        """Run the input blocks, collecting one hint per block.

        Returns the hints in reverse (deepest-first) order: pooled style
        vectors when style_modulation is on, otherwise zero-conv feature
        maps flattened to (b, h*w, c) token sequences.
        """
        hints = []
        h = x.to(self.dtype)

        for idx, module in enumerate(self.input_blocks):
            h = module(h, emb, context, **additional_context)

            if self.style_modulation:
                # Global average pool -> zero-init linear style vector.
                hint = self.zero_layers[idx](h.mean(dim=[2, 3]))
                hints.append(hint)

            else:
                hint = self.zero_layers[idx](h)
                hint = rearrange(hint, "b c h w -> b (h w) c").contiguous()
                hints.append(hint)

        hints.reverse()
        return hints
|
refnet/modules/transformer.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from functools import partial
|
| 5 |
+
from einops import rearrange
|
| 6 |
+
|
| 7 |
+
from refnet.util import checkpoint_wrapper, exists
|
| 8 |
+
from refnet.modules.layers import FeedForward, Normalize, zero_module, RMSNorm
|
| 9 |
+
from refnet.modules.attention import MemoryEfficientAttention, MultiModalAttention, MultiScaleCausalAttention
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class BasicTransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention, optional cross-attention,
    and a feed-forward network, each with a residual connection.

    The self/cross attention implementations are selected by name from
    ``ATTENTION_MODES``; their keyword contracts (``casual``, ``rope``,
    ``qk_norm``, ``context_dim_2``) are defined in refnet.modules.attention.
    """
    # NOTE: "casual" (sic) is the spelling these attention classes expect;
    # it presumably means "causal" — keep the kwarg name for compatibility.
    ATTENTION_MODES = {
        "vanilla": MemoryEfficientAttention,
        "multi-scale": MultiScaleCausalAttention,
        "multi-modal": MultiModalAttention,
    }
    def __init__(
        self,
        dim,
        n_heads = None,
        d_head = 64,
        dropout = 0.,
        context_dim = None,
        gated_ff = True,
        ff_mult = 4,
        checkpoint = True,
        disable_self_attn = False,
        disable_cross_attn = False,
        self_attn_type = "vanilla",
        cross_attn_type = "vanilla",
        rotary_positional_embedding = False,
        context_dim_2 = None,
        casual_self_attn = False,
        casual_cross_attn = False,
        qk_norm = False,
        norm_type = "layer",
    ):
        super().__init__()
        assert self_attn_type in self.ATTENTION_MODES
        assert cross_attn_type in self.ATTENTION_MODES
        self_attn_cls = self.ATTENTION_MODES[self_attn_type]
        crossattn_cls = self.ATTENTION_MODES[cross_attn_type]

        if norm_type == "layer":
            norm_cls = nn.LayerNorm
        elif norm_type == "rms":
            norm_cls = RMSNorm
        else:
            raise NotImplementedError(f"Normalization {norm_type} is not implemented.")

        self.dim = dim
        self.disable_self_attn = disable_self_attn
        self.disable_cross_attn = disable_cross_attn

        # attn1: self-attention, unless disable_self_attn turns it into a
        # cross-attention over `context`.
        self.attn1 = self_attn_cls(
            query_dim = dim,
            heads = n_heads,
            dim_head = d_head,
            dropout = dropout,
            context_dim = context_dim if self.disable_self_attn else None,
            casual = casual_self_attn,
            rope = rotary_positional_embedding,
            qk_norm = qk_norm
        )
        # attn2: cross-attention over `context` (and optionally a second
        # context via context_dim_2); omitted entirely when disabled.
        self.attn2 = crossattn_cls(
            query_dim = dim,
            context_dim = context_dim,
            context_dim_2 = context_dim_2,
            heads = n_heads,
            dim_head = d_head,
            dropout = dropout,
            casual = casual_cross_attn
        ) if not disable_cross_attn else None

        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, mult=ff_mult)
        self.norm1 = norm_cls(dim)
        self.norm2 = norm_cls(dim) if not disable_cross_attn else None
        self.norm3 = norm_cls(dim)
        # Runtime knobs mutated externally (e.g. by sampling hooks).
        self.reference_scale = 1
        self.scale_factor = None
        # Read by @checkpoint_wrapper to decide whether to use activation
        # checkpointing for this block's forward.
        self.checkpoint = checkpoint

    @checkpoint_wrapper
    def forward(self, x, context=None, mask=None, emb=None, **kwargs):
        """Pre-norm residual pipeline: attn1 -> (attn2) -> ff.

        `emb` is accepted for interface uniformity with sibling blocks but is
        not used here. attn2 is called positionally with
        (context, mask, reference_scale, scale_factor) — the cross-attention
        class' expected argument order.
        """
        x = self.attn1(self.norm1(x), **kwargs) + x
        if not self.disable_cross_attn:
            x = self.attn2(self.norm2(x), context, mask, self.reference_scale, self.scale_factor) + x
        x = self.ff(self.norm3(x)) + x
        return x
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class SelfInjectedTransformerBlock(BasicTransformerBlock):
    """Transformer block whose self-attention can be "injected" with cached
    reference features (``self.bank``), set externally by a hook before the
    denoising pass.

    With a bank present, attn1 attends from the input to either the
    concatenation of input+bank ("concat" injection) or their sum; without a
    bank it degrades to the plain BasicTransformerBlock forward.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Reference features injected by an external hook; None disables injection.
        self.bank = None
        # Optional projection of the timestep embedding added onto the bank.
        self.time_proj = None
        # "concat" or additive injection of the bank into attn1's context.
        self.injection_type = "concat"
        # Bound parent forward, used verbatim when no bank is set.
        self.forward_without_bank = super().forward

    @checkpoint_wrapper
    def forward(self, x, context=None, mask=None, emb=None, **kwargs):
        if exists(self.bank):
            bank = self.bank
            if bank.shape[0] != x.shape[0]:
                # Broadcast a single cached bank across the batch
                # (e.g. cond/uncond duplication during CFG sampling).
                bank = bank.repeat(x.shape[0], 1, 1)
            if exists(self.time_proj) and exists(emb):
                # Condition the bank on the current timestep.
                bank = bank + self.time_proj(emb).unsqueeze(1)
            x_in = self.norm1(x)

            # Keep attn1's masking threshold in sync with attn2's
            # (attn2's value is the one configured externally).
            self.attn1.mask_threshold = self.attn2.mask_threshold
            x = self.attn1(
                x_in,
                torch.cat([x_in, bank], 1) if self.injection_type == "concat" else x_in + bank,
                mask = mask,
                scale_factor = self.scale_factor,
                **kwargs
            ) + x

            x = self.attn2(
                self.norm2(x),
                context,
                mask = mask,
                scale = self.reference_scale,
                scale_factor = self.scale_factor
            ) + x

            x = self.ff(self.norm3(x)) + x
        else:
            # No bank: behave exactly like the parent block.
            x = self.forward_without_bank(x, context, mask, emb)
        return x
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class SelfTransformerBlock(nn.Module):
    """Lightweight self-attention block: pre-norm attention plus a
    zero-initialized SiLU MLP, each with a residual connection.

    When ``reshape`` is True the input is a 4-D feature map ``(b, c, h, w)``
    that is flattened to a token sequence for attention and restored
    afterwards; when False the input is passed through as-is (presumably
    already a sequence — confirmed by the rearrange being skipped).
    """
    def __init__(
        self,
        dim,
        dim_head = 64,
        dropout = 0.,
        mlp_ratio = 4,
        checkpoint = True,
        casual_attn = False,
        reshape = True
    ):
        super().__init__()
        # NOTE: "casual" (sic) is the kwarg name the attention class expects.
        self.attn = MemoryEfficientAttention(query_dim=dim, heads=dim//dim_head, dropout=dropout, casual=casual_attn)
        # Zero-initialized output projection so the block starts as identity.
        self.ff = nn.Sequential(
            nn.Linear(dim, dim * mlp_ratio),
            nn.SiLU(),
            zero_module(nn.Linear(dim * mlp_ratio, dim))
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.reshape = reshape
        # Read by @checkpoint_wrapper to toggle activation checkpointing.
        self.checkpoint = checkpoint

    @checkpoint_wrapper
    def forward(self, x, context=None):
        if self.reshape:
            # Unpack the spatial shape only when we actually flatten: the
            # original unconditional unpack crashed on non-4D inputs even
            # though h/w were unused when reshape is False.
            b, c, h, w = x.shape
            x = rearrange(x, 'b c h w -> b (h w) c').contiguous()

        # `context if exists(context) else None` collapsed to `context`.
        x = self.attn(self.norm1(x), context) + x
        x = self.ff(self.norm2(x)) + x

        if self.reshape:
            x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        return x
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class Transformer(nn.Module):
    """Spatial transformer wrapper: projects a ``(b, c, h, w)`` feature map
    into an inner token space, runs a stack of transformer blocks, projects
    back, and adds a global residual.

    ``use_linear`` chooses nn.Linear projections applied on the flattened
    sequence; otherwise 1x1 convolutions applied on the feature map.
    """
    transformer_type = {
        "vanilla": BasicTransformerBlock,
        "self-injection": SelfInjectedTransformerBlock,
    }
    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None, use_linear=False,
                 use_checkpoint=True, type="vanilla", transformer_config=None, **kwargs):
        super().__init__()
        transformer_block = self.transformer_type[type]
        # Normalize context_dim to one entry per block. (The original code
        # re-checked isinstance(..., list) right after forcing it to a list —
        # a dead branch, removed.)
        if not isinstance(context_dim, list):
            context_dim = [context_dim]
        if depth != len(context_dim):
            context_dim = depth * [context_dim[0]]

        proj_layer = nn.Linear if use_linear else partial(nn.Conv2d, kernel_size=1, stride=1, padding=0)
        inner_dim = n_heads * d_head

        self.in_channels = in_channels
        self.proj_in = proj_layer(in_channels, inner_dim)
        self.transformer_blocks = nn.ModuleList([
            transformer_block(
                inner_dim,
                n_heads,
                d_head,
                dropout = dropout,
                context_dim = context_dim[d],
                checkpoint = use_checkpoint,
                **(transformer_config or {}),
                **kwargs
            ) for d in range(depth)
        ])
        # Zero-initialized so the whole module starts as identity (x + x_in).
        self.proj_out = zero_module(proj_layer(inner_dim, in_channels))
        self.norm = Normalize(in_channels)
        self.use_linear = use_linear

    def forward(self, x, context=None, mask=None, emb=None, *args, **additional_context):
        # note: if no context is given, cross-attention defaults to self-attention
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        # Conv projections run on the 4-D map (before flattening);
        # linear projections run on the flattened sequence (after).
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for block in self.transformer_blocks:
            x = block(x, context=context, mask=mask, emb=emb, grid_size=(h, w), *args, **additional_context)
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def SpatialTransformer(*args, **kwargs):
    """Factory for a plain (non-injected) spatial ``Transformer``."""
    return Transformer(*args, type="vanilla", **kwargs)
|
| 230 |
+
|
| 231 |
+
def SelfInjectTransformer(*args, **kwargs):
    """Factory for a ``Transformer`` built from self-injected blocks."""
    return Transformer(*args, type="self-injection", **kwargs)
|
refnet/modules/unet.py
ADDED
|
@@ -0,0 +1,421 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from functools import partial
|
| 5 |
+
from refnet.modules.attention import MemoryEfficientAttention
|
| 6 |
+
from refnet.util import exists
|
| 7 |
+
from refnet.modules.transformer import (
|
| 8 |
+
SelfTransformerBlock,
|
| 9 |
+
Transformer,
|
| 10 |
+
SpatialTransformer,
|
| 11 |
+
SelfInjectTransformer,
|
| 12 |
+
)
|
| 13 |
+
from refnet.ldm.openaimodel import (
|
| 14 |
+
timestep_embedding,
|
| 15 |
+
conv_nd,
|
| 16 |
+
TimestepBlock,
|
| 17 |
+
zero_module,
|
| 18 |
+
ResBlock,
|
| 19 |
+
linear,
|
| 20 |
+
Downsample,
|
| 21 |
+
Upsample,
|
| 22 |
+
normalization,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def hack_inference_forward(model):
    """Monkey-patch *model* so its ``forward`` becomes ``InferenceForward``.

    ``__get__`` invokes the descriptor protocol to bind the plain function to
    the instance, so the patched forward receives ``model`` as ``self``.
    """
    bound_forward = InferenceForward.__get__(model, model.__class__)
    model.forward = bound_forward
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def InferenceForward(self, x, timesteps=None, y=None, *args, **kwargs):
    """Inference-time replacement for ``UNetModel.forward``.

    Differences from the training forward: the sinusoidal features are kept
    in full precision and cast to ``self.dtype`` only after the embedding
    MLP, the label embedding is moved to ``emb``'s device, and the trunk
    output is cast back to the input dtype before the output head.
    """
    t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
    emb = self.time_embed(t_emb).to(self.dtype)
    assert (y is not None) == (
        self.num_classes is not None
    ), "must specify y if and only if the model is class-conditional"
    if self.num_classes is not None:
        assert y.shape[0] == x.shape[0]
        # label_emb may live on a different device at inference time.
        emb = emb + self.label_emb(y.to(emb.device))
        # Re-cast: the addition above may have promoted emb's dtype.
        emb = emb.to(self.dtype)
    h = self._forward(x, emb, *args, **kwargs)
    return self.out(h.to(x.dtype))
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """
    # Dispatch constants
    _D_TIMESTEP = 0
    _D_TRANSFORMER = 1
    _D_OTHER = 2

    @classmethod
    def _classify(cls, layer):
        """Return the dispatch tag matching *layer*'s call signature."""
        if isinstance(layer, TimestepBlock):
            return cls._D_TIMESTEP
        if isinstance(layer, Transformer):
            return cls._D_TRANSFORMER
        return cls._D_OTHER

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Cache dispatch tags once at construction (before any FSDP
        # wrapping), so forward() runs without isinstance checks and is
        # immune to wrapper classes breaking them.
        self._dispatch = tuple(self._classify(layer) for layer in self)

    def forward(self, x, emb=None, context=None, mask=None, **additional_context):
        for tag, layer in zip(self._dispatch, self):
            if tag == self._D_TRANSFORMER:
                x = layer(x, context, mask, emb, **additional_context)
            elif tag == self._D_TIMESTEP:
                x = layer(x, emb)
            else:
                x = layer(x)
        return x
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class UNetModel(nn.Module):
    """Diffusion UNet with configurable attention (self-only or spatial
    transformer), optional class conditioning, and optional auxiliary
    module lists (map/warp/style) used by reference-based pipelines.

    Channel bookkeeping follows the OpenAI guided-diffusion layout:
    ``input_block_chans`` records encoder skip widths, ``ds`` tracks the
    current downsampling factor matched against ``attention_resolutions``.
    """
    transformers = {
        "vanilla": SpatialTransformer,
        "selfinj": SelfInjectTransformer,
    }
    def __init__(
        self,
        in_channels,
        model_channels,
        num_res_blocks,
        attention_resolutions,
        out_channels = 4,
        dropout = 0,
        channel_mult = (1, 2, 4, 8),
        conv_resample = True,
        dims = 2,
        num_classes = None,
        use_checkpoint = False,
        num_heads = -1,
        num_head_channels = -1,
        use_scale_shift_norm = False,
        resblock_updown = False,
        use_spatial_transformer = False, # custom transformer support
        transformer_depth = 1, # custom transformer support
        context_dim = None, # custom transformer support
        disable_self_attentions = None,
        disable_cross_attentions = False,
        num_attention_blocks = None,
        use_linear_in_transformer = False,
        adm_in_channels = None,
        transformer_type = "vanilla",
        map_module = False,
        warp_module = False,
        style_modulation = False,
        discard_final_layers = False, # for reference net
        additional_transformer_config = None,
        in_channels_fg = None,
        in_channels_bg = None,
    ):
        super().__init__()
        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            from omegaconf.listconfig import ListConfig
            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        assert num_heads > -1 or num_head_channels > -1, 'Either num_heads or num_head_channels has to be set'
        if isinstance(num_res_blocks, int):
            # Broadcast a single count to every resolution level.
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                 "as a list/tuple (per-level) with the same length as channel_mult")
            self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                  f"attention will still not be set.")

        self.num_classes = num_classes
        self.model_channels = model_channels
        self.dtype = torch.float32

        if isinstance(transformer_depth, int):
            transformer_depth = len(channel_mult) * [transformer_depth]
        transformer_depth_middle = transformer_depth[-1]
        time_embed_dim = model_channels * 4
        # Pre-bind the config shared by every ResBlock / transformer below.
        resblock = partial(
            ResBlock,
            emb_channels = time_embed_dim,
            dropout = dropout,
            dims = dims,
            use_checkpoint = use_checkpoint,
            use_scale_shift_norm = use_scale_shift_norm,
        )
        transformer = partial(
            self.transformers[transformer_type],
            context_dim = context_dim,
            use_linear = use_linear_in_transformer,
            use_checkpoint = use_checkpoint,
            disable_self_attn = disable_self_attentions,
            disable_cross_attn = disable_cross_attentions,
            transformer_config = additional_transformer_config
        )

        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
            elif self.num_classes == "continuous":
                print("setting up linear c_adm embedding layer")
                self.label_emb = nn.Linear(1, time_embed_dim)
            elif self.num_classes == "sequential":
                assert adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
                        linear(adm_in_channels, time_embed_dim),
                        nn.SiLU(),
                        linear(time_embed_dim, time_embed_dim),
                    )
                )
            else:
                raise ValueError()

        # ---- Encoder ----
        self.input_blocks = nn.ModuleList([
            TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))
        ])
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [resblock(ch, out_channels=mult * model_channels)]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels > -1:
                        current_num_heads = ch // num_head_channels
                        current_head_dim = num_head_channels
                    else:
                        current_num_heads = num_heads
                        current_head_dim = ch // num_heads

                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                        layers.append(
                            SelfTransformerBlock(ch, current_head_dim)
                            if not use_spatial_transformer
                            else transformer(
                                ch, current_num_heads, current_head_dim,
                                depth=transformer_depth[level],
                            )
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                # Downsample between levels (except after the last level).
                out_ch = ch
                self.input_blocks.append(TimestepEmbedSequential(
                    resblock(ch, out_channels=out_ch, down=True) if resblock_updown
                    else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                ))
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2

        # ---- Middle ----
        if num_head_channels > -1:
            current_num_heads = ch // num_head_channels
            current_head_dim = num_head_channels
        else:
            current_num_heads = num_heads
            current_head_dim = ch // num_heads
        self.middle_block = TimestepEmbedSequential(
            resblock(ch),
            SelfTransformerBlock(ch, current_head_dim) if not use_spatial_transformer
            else transformer(ch, current_num_heads, current_head_dim, depth=transformer_depth_middle),
            resblock(ch),
        )

        # ---- Decoder (plus optional auxiliary per-skip modules) ----
        self.output_blocks = nn.ModuleList([])
        self.map_modules = nn.ModuleList([])
        self.warp_modules = nn.ModuleList([])
        self.style_modules = nn.ModuleList([])

        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(self.num_res_blocks[level] + 1):
                ich = input_block_chans.pop()
                layers = [resblock(ch + ich, out_channels=model_channels * mult)]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    if num_head_channels > -1:
                        current_num_heads = ch // num_head_channels
                        current_head_dim = num_head_channels
                    else:
                        current_num_heads = num_heads
                        current_head_dim = ch // num_heads

                    if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
                        layers.append(
                            SelfTransformerBlock(ch, current_head_dim) if not use_spatial_transformer
                            else transformer(
                                ch, current_num_heads, current_head_dim, depth=transformer_depth[level]
                            )
                        )
                if level and i == self.num_res_blocks[level]:
                    # Upsample at the end of every level except the last (0).
                    out_ch = ch
                    layers.append(
                        resblock(ch, up=True) if resblock_updown else Upsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                if level == 0 and discard_final_layers:
                    # Reference-net mode: stop before the final decoder stage.
                    break

                if map_module:
                    self.map_modules.append(nn.ModuleList([
                        MemoryEfficientAttention(
                            ich,
                            heads = ich // num_head_channels,
                            dim_head = num_head_channels
                        ),
                        nn.Linear(time_embed_dim, ich)
                    ]))

                if warp_module:
                    self.warp_modules.append(nn.ModuleList([
                        MemoryEfficientAttention(
                            ich,
                            heads = ich // num_head_channels,
                            dim_head = num_head_channels
                        ),
                        nn.Linear(time_embed_dim, ich)
                    ]))

                    # self.warp_modules.append(nn.ModuleList([
                    #     SpatialTransformer(ich, ich//num_head_channels, num_head_channels),
                    #     nn.Linear(time_embed_dim, ich)
                    # ]))

                if style_modulation:
                    self.style_modules.append(nn.ModuleList([
                        nn.LayerNorm(ch*2),
                        nn.Linear(time_embed_dim, ch*2),
                        zero_module(nn.Linear(ch*2, ch*2))
                    ]))

        if not discard_final_layers:
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
            )

        # Optional zero-initialized stems for extra foreground/background inputs.
        self.conv_fg = zero_module(
            conv_nd(dims, in_channels_fg, model_channels, 3, padding=1)
        ) if exists(in_channels_fg) else None
        self.conv_bg = zero_module(
            conv_nd(dims, in_channels_bg, model_channels, 3, padding=1)
        ) if exists(in_channels_bg) else None

    def forward(self, x, timesteps=None, y=None, *args, **kwargs):
        """Training forward: time (and class) embedding, trunk, output head."""
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
        emb = self.time_embed(t_emb)
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y.to(self.dtype))

        h = self._forward(x, emb, *args, **kwargs)
        return self.out(h).to(x.dtype)

    def _forward(
        self,
        x,
        emb,
        control = None,  # accepted for subclass compatibility; unused here
        context = None,
        mask = None,
        **additional_context
    ):
        """Standard UNet trunk: encoder (collecting skips), middle block,
        decoder with channel-concatenated skip connections."""
        hs = []
        h = x.to(self.dtype)

        for module in self.input_blocks:
            h = module(h, emb, context, mask, **additional_context)
            hs.append(h)

        h = self.middle_block(h, emb, context, mask, **additional_context)

        for module in self.output_blocks:
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context, mask, **additional_context)
        return h
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
class DualCondUNetXL(UNetModel):
    """UNet variant that adds external hint features ("control") into
    selected encoder and decoder blocks.

    Args:
        hint_encoder_index: indices of ``input_blocks`` whose output gets a
            hint residual added.
        hint_decoder_index: indices of ``output_blocks`` whose output gets a
            hint residual added.
    """
    def __init__(
        self,
        hint_encoder_index = (0, 3, 6, 8),
        hint_decoder_index = (),
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.hint_encoder_index = hint_encoder_index
        self.hint_decoder_index = hint_decoder_index

    def _forward(self, x, emb, concat=None, control=None, context=None, mask=None, **additional_context):
        """UNet trunk with optional channel-concat conditioning and additive
        control hints.

        ``control`` is an iterable yielding one hint tensor per marked block,
        consumed encoder-first then decoder. When it is None, hint injection
        is skipped entirely (the original code crashed on ``iter(None)``
        despite declaring ``control=None`` as the default).
        """
        h = x.to(self.dtype)
        hs = []

        if exists(concat):
            h = torch.cat([h, concat], 1)

        use_control = exists(control)
        control_iter = iter(control) if use_control else None
        for idx, module in enumerate(self.input_blocks):
            h = module(h, emb, context, mask, **additional_context)

            if use_control and idx in self.hint_encoder_index:
                h += next(control_iter)
            hs.append(h)

        h = self.middle_block(h, emb, context, mask, **additional_context)

        for idx, module in enumerate(self.output_blocks):
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context, mask, **additional_context)

            if use_control and idx in self.hint_decoder_index:
                h += next(control_iter)

        return h
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
class ReferenceNet(UNetModel):
    """Encoder-style UNet used to process a reference image.

    Built with ``discard_final_layers=True`` (no output head), and its
    ``forward``/``_forward`` deliberately return None: the pass is run for
    its side effects (e.g. features captured by attention hooks), not for
    a tensor result.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(discard_final_layers=True, *args, **kwargs)

    def forward(self, x, timesteps=None, y=None, *args, **kwargs):
        """Run the trunk for its side effects; returns None by design."""
        emb = self.time_embed(
            timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
        )

        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y.to(self.dtype))
        self._forward(x, emb, *args, **kwargs)

    def _forward(self, *args, **kwargs):
        # Run the parent trunk but discard its output.
        super()._forward(*args, **kwargs)
        return None
|
refnet/modules/unet_old.py
ADDED
|
@@ -0,0 +1,596 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from functools import partial
|
| 5 |
+
from refnet.util import exists
|
| 6 |
+
from refnet.modules.transformer import (
|
| 7 |
+
SelfTransformerBlock,
|
| 8 |
+
Transformer,
|
| 9 |
+
SpatialTransformer,
|
| 10 |
+
rearrange
|
| 11 |
+
)
|
| 12 |
+
from refnet.ldm.openaimodel import (
|
| 13 |
+
timestep_embedding,
|
| 14 |
+
conv_nd,
|
| 15 |
+
TimestepBlock,
|
| 16 |
+
zero_module,
|
| 17 |
+
ResBlock,
|
| 18 |
+
linear,
|
| 19 |
+
Downsample,
|
| 20 |
+
Upsample,
|
| 21 |
+
normalization,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
try:
|
| 25 |
+
import xformers
|
| 26 |
+
import xformers.ops
|
| 27 |
+
XFORMERS_IS_AVAILBLE = True
|
| 28 |
+
except:
|
| 29 |
+
XFORMERS_IS_AVAILBLE = False
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def hack_inference_forward(model):
    """Monkey-patch ``model.forward`` with the inference-only forward below,
    bound to ``model`` as an instance method."""
    model.forward = InferenceForward.__get__(model, model.__class__)


def InferenceForward(self, x, timesteps=None, y=None, *args, **kwargs):
    """Inference-time UNet forward: embeds timesteps (and class labels when the
    model is class-conditional), runs the trunk, and projects through
    ``self.out`` in the dtype of the input ``x``."""
    emb = self.time_embed(
        timestep_embedding(timesteps, self.model_channels, repeat_only=False)
    ).to(self.dtype)

    assert (y is not None) == (self.num_classes is not None), \
        "must specify y if and only if the model is class-conditional"
    if self.num_classes is not None:
        assert y.shape[0] == x.shape[0]
        emb = (emb + self.label_emb(y.to(emb.device))).to(self.dtype)

    h = self._forward(x, emb, *args, **kwargs)
    return self.out(h.to(x.dtype))
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """A sequential container that forwards the right extra inputs to each
    child: timestep embeddings to ``TimestepBlock`` children, (context, mask,
    extra kwargs) to ``Transformer`` children, nothing extra otherwise.
    """

    def forward(self, x, emb, context=None, mask=None, **additional_context):
        out = x
        for child in self:
            if isinstance(child, TimestepBlock):
                out = child(out, emb)
            elif isinstance(child, Transformer):
                out = child(out, context, mask, **additional_context)
            else:
                out = child(out)
        return out
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class UNetModel(nn.Module):
    """
    Full UNet backbone with attention and timestep embedding, derived from the
    LDM ``openaimodel`` UNet: an encoder (``input_blocks``), a ``middle_block``
    and a decoder (``output_blocks``) built from ResBlocks plus optional
    (spatial-)transformer layers, with skip connections between encoder and
    decoder stages.

    The optional ``map_module``/``warp_module``/``style_modulation`` flags
    build auxiliary per-decoder-stage module lists that are NOT used by this
    class's own forward pass — presumably consumed by subclasses or external
    hooks (TODO confirm).
    """

    # Registry of available transformer implementations, keyed by
    # the ``transformer_type`` constructor argument.
    transformers = {
        "vanilla": SpatialTransformer,
    }
    def __init__(
        self,
        in_channels,                      # channels of the (noisy) input
        model_channels,                   # base channel width
        out_channels,                     # channels of the predicted output
        num_res_blocks,                   # int or per-level list
        attention_resolutions,            # downsample factors that get attention
        dropout = 0,
        channel_mult = (1, 2, 4, 8),      # per-level channel multipliers
        conv_resample = True,
        dims = 2,                         # 1-/2-/3-D convolutions
        num_classes = None,               # int, "continuous" or "sequential"
        use_checkpoint = False,
        num_heads = -1,
        num_head_channels = -1,
        use_scale_shift_norm = False,
        resblock_updown = False,
        use_spatial_transformer = False,  # custom transformer support
        transformer_depth = 1,            # custom transformer support
        context_dim = None,               # custom transformer support
        disable_self_attentions = None,
        disable_cross_attentions = None,
        num_attention_blocks = None,
        use_linear_in_transformer = False,
        adm_in_channels = None,
        transformer_type = "vanilla",
        map_module = False,
        warp_module = False,
        style_modulation = False,
    ):
        super().__init__()
        if use_spatial_transformer:
            assert exists(context_dim) or disable_cross_attentions, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
            assert transformer_type in self.transformers.keys(), f'Assigned transformer is not implemented.. Choices: {self.transformers.keys()}'
        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            from omegaconf.listconfig import ListConfig
            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        assert num_heads > -1 or num_head_channels > -1, 'Either num_heads or num_head_channels has to be set'
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                 "as a list/tuple (per-level) with the same length as channel_mult")
            self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                  f"attention will still not be set.")

        self.num_classes = num_classes
        self.model_channels = model_channels
        self.dtype = torch.float32

        if isinstance(transformer_depth, int):
            transformer_depth = len(channel_mult) * [transformer_depth]
        transformer_depth_middle = transformer_depth[-1]
        time_embed_dim = model_channels * 4
        # Factories with the shared constructor arguments pre-bound.
        resblock = partial(
            ResBlock,
            emb_channels=time_embed_dim,
            dropout=dropout,
            dims=dims,
            use_checkpoint=use_checkpoint,
            use_scale_shift_norm=use_scale_shift_norm,
        )
        transformer = partial(
            self.transformers[transformer_type],
            context_dim=context_dim,
            use_linear=use_linear_in_transformer,
            use_checkpoint=use_checkpoint,
            disable_self_attn=disable_self_attentions,
            disable_cross_attn=disable_cross_attentions,
        )

        # MLP that lifts the sinusoidal timestep embedding to time_embed_dim.
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
            elif self.num_classes == "continuous":
                print("setting up linear c_adm embedding layer")
                self.label_emb = nn.Linear(1, time_embed_dim)
            elif self.num_classes == "sequential":
                assert adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
                        linear(adm_in_channels, time_embed_dim),
                        nn.SiLU(),
                        linear(time_embed_dim, time_embed_dim),
                    )
                )
            else:
                raise ValueError()

        # ---- encoder ----
        self.input_blocks = nn.ModuleList([
            TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))
        ])
        input_block_chans = [model_channels]  # skip-connection channel bookkeeping
        ch = model_channels
        ds = 1  # current total downsampling factor
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [resblock(ch, out_channels=mult * model_channels)]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    # num_head_channels takes precedence over num_heads.
                    if num_head_channels > -1:
                        current_num_heads = ch // num_head_channels
                        current_head_dim = num_head_channels
                    else:
                        current_num_heads = num_heads
                        current_head_dim = ch // num_heads

                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                        layers.append(
                            SelfTransformerBlock(ch, current_head_dim)
                            if not use_spatial_transformer
                            else transformer(
                                ch, current_num_heads, current_head_dim, depth=transformer_depth[level]
                            )
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                # Downsample between levels (not after the last one).
                out_ch = ch
                self.input_blocks.append(TimestepEmbedSequential(
                    resblock(ch, out_channels=out_ch, down=True) if resblock_updown
                    else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                ))
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2

        # ---- middle block ----
        if num_head_channels > -1:
            current_num_heads = ch // num_head_channels
            current_head_dim = num_head_channels
        else:
            current_num_heads = num_heads
            current_head_dim = ch // num_heads
        self.middle_block = TimestepEmbedSequential(
            resblock(ch),
            SelfTransformerBlock(ch, current_head_dim) if not use_spatial_transformer
            else transformer(ch, current_num_heads, current_head_dim, depth=transformer_depth_middle),
            resblock(ch),
        )

        # ---- decoder ----
        self.output_blocks = nn.ModuleList([])
        self.map_modules = nn.ModuleList([])
        self.warp_modules = nn.ModuleList([])
        self.style_modules = nn.ModuleList([])

        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(self.num_res_blocks[level] + 1):
                ich = input_block_chans.pop()  # matching encoder skip channels
                layers = [resblock(ch + ich, out_channels=model_channels * mult)]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    if num_head_channels > -1:
                        current_num_heads = ch // num_head_channels
                        current_head_dim = num_head_channels
                    else:
                        current_num_heads = num_heads
                        current_head_dim = ch // num_heads

                    if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
                        layers.append(
                            SelfTransformerBlock(ch, current_head_dim) if not use_spatial_transformer
                            else transformer(
                                ch, current_num_heads, current_head_dim, depth=transformer_depth[level]
                            )
                        )
                if level and i == self.num_res_blocks[level]:
                    # Upsample at the end of each level except the outermost.
                    out_ch = ch
                    layers.append(
                        resblock(ch, up=True) if resblock_updown else Upsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))

                # NOTE(review): the following auxiliary modules appear to be
                # built once per decoder block (sized on skip/output channels)
                # but are never used in this class's forward — confirm their
                # consumer before relying on their ordering.
                if map_module:
                    self.map_modules.append(
                        SelfTransformerBlock(ich)
                    )

                if warp_module:
                    self.warp_modules.append(
                        SelfTransformerBlock(ich)
                    )

                if style_modulation:
                    self.style_modules.append(nn.ModuleList([
                        nn.LayerNorm(ch*2),
                        nn.Linear(time_embed_dim, ch*2),
                        zero_module(nn.Linear(ch*2, ch*2))
                    ]))

        # Final projection; zero-initialized conv so the untrained model
        # predicts zeros.
        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
        )

    def forward(self, x, timesteps=None, y=None, *args, **kwargs):
        """Embed timesteps (and optional class labels), run the trunk, and
        project the result back to ``x``'s dtype."""
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
        emb = self.time_embed(t_emb)
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y.to(self.dtype))

        h = self._forward(x, emb, *args, **kwargs)
        return self.out(h).to(x.dtype)

    def _forward(self, x, emb, control=None, context=None, mask=None, **additional_context):
        """Plain encoder/middle/decoder pass with skip connections.

        ``control`` is accepted but unused here; subclasses (e.g.
        DualCondUNet) override this method to consume it.
        """
        hs = []
        h = x.to(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, context, mask, **additional_context)
            hs.append(h)

        h = self.middle_block(h, emb, context, mask, **additional_context)

        for module in self.output_blocks:
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context, mask, **additional_context)
        return h
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
class DualCondUNet(UNetModel):
    """UNet that adds externally computed hint features to a fixed set of
    encoder blocks (indices ``hint_encoder_index``); the decoder is the plain
    skip-connected pass. ``control`` must yield one tensor per hint index,
    in encoder order.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.hint_encoder_index = [0, 3, 6, 9, 11]

    def _forward(self, x, emb, control=None, context=None, mask=None, **additional_context):
        feat = x.to(self.dtype)
        hints = iter(control)

        skips = []
        for i, block in enumerate(self.input_blocks):
            feat = block(feat, emb, context, mask, **additional_context)
            if i in self.hint_encoder_index:
                feat = feat + next(hints)
            # Skip connection stores the hint-augmented activation.
            skips.append(feat)

        feat = self.middle_block(feat, emb, context, mask, **additional_context)

        for block in self.output_blocks:
            feat = block(
                torch.cat([feat, skips.pop()], dim=1), emb, context, mask, **additional_context
            )

        return feat
|
| 342 |
+
|
| 343 |
+
class OldUnet(UNetModel):
    """Legacy UNet variant that derives hint features itself from a semantic
    condition image (``control[0]``) via a ControlNet-style encoder, instead
    of receiving precomputed hints."""

    def __init__(self, c_channels, model_channels, channel_mult, *args, **kwargs):
        super().__init__(channel_mult=channel_mult, model_channels=model_channels, *args, **kwargs)
        # NOTE(review): this string literal is a no-op statement (not a class
        # docstring); kept for its attribution value.
        """
        Semantic condition input blocks, implementation from ControlNet.
        Paper: Adding Conditional Control to Text-to-Image Diffusion Models
        Authors: Lvmin Zhang, Anyi Rao, and Maneesh Agrawala
        Code link: https://github.com/lllyasviel/ControlNet
        """
        from refnet.modules.encoder import SimpleEncoder, MultiEncoder
        # self.semantic_input_blocks = SimpleEncoder(c_channels, model_channels)
        self.semantic_input_blocks = MultiEncoder(c_channels, model_channels, channel_mult)
        # Encoder block indices that receive one hint tensor each.
        self.hint_encoder_index = [0, 3, 6, 9, 11]

    def forward(self, x, timesteps=None, control=None, context=None, y=None, **kwargs):
        # control[0] is the semantic condition image; the hint encoder turns it
        # into per-resolution feature maps consumed below.
        concat = control[0].to(self.dtype)
        context = context.to(self.dtype)

        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb).to(self.dtype)

        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

        h = x.to(self.dtype)
        hints = self.semantic_input_blocks(concat, emb, context)

        for idx, module in enumerate(self.input_blocks):
            h = module(h, emb, context)
            if idx in self.hint_encoder_index:
                # Hints are consumed front-to-back, one per listed index.
                h += hints.pop(0)

            hs.append(h)

        h = self.middle_block(h, emb, context)

        for module in self.output_blocks:
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
        h = h.to(x.dtype)
        return self.out(h)
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
class UNetEncoder(nn.Module):
    """Encoder-only half of the UNet, used as a hint/condition feature
    extractor. For each encoder block a zero-initialized projection
    (``zero_layers``) emits one hint: a pooled style vector when
    ``style_modulation`` is set, otherwise a flattened spatial feature map.
    ``forward`` returns the list of hints, deepest resolution first.
    """

    transformers = {
        "vanilla": SpatialTransformer,
    }

    def __init__(
        self,
        in_channels,
        model_channels,
        num_res_blocks,
        attention_resolutions,
        dropout = 0,
        channel_mult = (1, 2, 4, 8),
        conv_resample = True,
        dims = 2,
        num_classes = None,
        use_checkpoint = False,
        num_heads = -1,
        num_head_channels = -1,
        use_scale_shift_norm = False,
        resblock_updown = False,
        use_spatial_transformer = False,  # custom transformer support
        transformer_depth = 1,            # custom transformer support
        context_dim = None,               # custom transformer support
        disable_self_attentions = None,
        disable_cross_attentions = None,
        num_attention_blocks = None,
        use_linear_in_transformer = False,
        adm_in_channels = None,
        transformer_type = "vanilla",
        style_modulation = False,
    ):
        super().__init__()
        if use_spatial_transformer:
            assert exists(
                context_dim) or disable_cross_attentions, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
            assert transformer_type in self.transformers.keys(), f'Assigned transformer is not implemented.. Choices: {self.transformers.keys()}'
            from omegaconf.listconfig import ListConfig
            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)


        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
        self.in_channels = in_channels
        self.model_channels = model_channels
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                 "as a list/tuple (per-level) with the same length as channel_mult")
            self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(
                map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                  f"attention will still not be set.")

        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = torch.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.style_modulation = style_modulation

        if isinstance(transformer_depth, int):
            transformer_depth = len(channel_mult) * [transformer_depth]

        time_embed_dim = model_channels * 4

        # Factories with the shared constructor arguments pre-bound.
        resblock = partial(
            ResBlock,
            emb_channels=time_embed_dim,
            dropout=dropout,
            dims=dims,
            use_checkpoint=use_checkpoint,
            use_scale_shift_norm=use_scale_shift_norm,
        )

        transformer = partial(
            self.transformers[transformer_type],
            context_dim=context_dim,
            use_linear=use_linear_in_transformer,
            use_checkpoint=use_checkpoint,
            disable_self_attn=disable_self_attentions,
            disable_cross_attn=disable_cross_attentions,
        )

        # 1x1 conv used as the (spatial) zero projection.
        zero_conv = partial(nn.Conv2d, kernel_size=1, stride=1, padding=0)

        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
            elif self.num_classes == "continuous":
                print("setting up linear c_adm embedding layer")
                self.label_emb = nn.Linear(1, time_embed_dim)
            elif self.num_classes == "sequential":
                assert adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
                        linear(adm_in_channels, time_embed_dim),
                        nn.SiLU(),
                        linear(time_embed_dim, time_embed_dim),
                    )
                )
            else:
                raise ValueError()

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        # One zero-initialized projection per input block: a Linear (for pooled
        # style vectors) or a 1x1 conv (for spatial hints).
        self.zero_layers = nn.ModuleList([zero_module(
            nn.Linear(model_channels, model_channels * 2) if style_modulation else
            zero_conv(model_channels, model_channels)
        )])

        ch = model_channels
        ds = 1  # current total downsampling factor
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [resblock(ch, out_channels=mult * model_channels)]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    # NOTE(review): assumes num_head_channels > 0 whenever any
                    # resolution gets attention — confirm against configs.
                    num_heads = ch // num_head_channels

                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                        layers.append(
                            SelfTransformerBlock(ch, num_head_channels)
                            if not use_spatial_transformer
                            else transformer(
                                ch, num_heads, num_head_channels, depth=transformer_depth[level]
                            )
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self.zero_layers.append(zero_module(
                    nn.Linear(ch, ch * 2) if style_modulation else zero_conv(ch, ch)
                ))

            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(TimestepEmbedSequential(
                    resblock(ch, out_channels=mult * model_channels, down=True) if resblock_updown else Downsample(
                        ch, conv_resample, dims=dims, out_channels=out_ch
                    )
                ))
                # NOTE(review): the style-modulation projection here widens to
                # min(model_channels*8, out_ch*4) instead of out_ch*2 as in the
                # other stages — confirm this asymmetry is intentional.
                self.zero_layers.append(zero_module(
                    nn.Linear(out_ch, min(model_channels * 8, out_ch * 4)) if style_modulation else
                    zero_conv(out_ch, out_ch)
                ))
                ch = out_ch
                ds *= 2

    def forward(self, x, timesteps = None, y = None, *args, **kwargs):
        """Embed timesteps (and optional labels) and return the hint list."""
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
        emb = self.time_embed(t_emb)

        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y.to(self.dtype))

        hs = self._forward(x, emb, *args, **kwargs)
        return hs

    def _forward(self, x, emb, context = None, **additional_context):
        hints = []
        h = x.to(self.dtype)

        for zero_layer, module in zip(self.zero_layers, self.input_blocks):
            h = module(h, emb, context, **additional_context)

            if self.style_modulation:
                # Global-average-pool to a per-channel style vector.
                hint = zero_layer(h.mean(dim=[2, 3]))
            else:
                # Spatial hint, flattened to (batch, h*w, channels).
                # NOTE(review): the rearrange is placed in the else-branch —
                # a pooled style vector has no spatial dims to rearrange;
                # confirm against the original indentation.
                hint = zero_layer(h)
                hint = rearrange(hint, "b c h w -> b (h w) c").contiguous()
            hints.append(hint)

        # Deepest-resolution hints first, matching the consumer's pop order.
        hints.reverse()
        return hints
|
refnet/sampling/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .denoiser import CFGDenoiser, DiffuserDenoiser
|
| 2 |
+
from .hook import UnetHook, torch_dfs
|
| 3 |
+
from .tps_transformation import tps_warp
|
| 4 |
+
from .sampler import KDiffusionSampler, kdiffusion_sampler_list
|
| 5 |
+
from .scheduler import get_noise_schedulers
|
| 6 |
+
|
| 7 |
+
def get_sampler_list():
    """Return the sorted names of every available sampler: one
    ``diffuser_<scheduler>`` entry per diffusers scheduler plus all
    k-diffusion samplers."""
    diffuser_names = [f"diffuser_{name}" for name in DiffuserDenoiser.scheduler_types]
    return sorted(diffuser_names + kdiffusion_sampler_list())
|
refnet/sampling/denoiser.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
import inspect
|
| 5 |
+
import os.path as osp
|
| 6 |
+
from typing import Union, Optional
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
from omegaconf import OmegaConf
|
| 9 |
+
from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
|
| 10 |
+
from diffusers.schedulers import (
|
| 11 |
+
DDIMScheduler,
|
| 12 |
+
DPMSolverMultistepScheduler,
|
| 13 |
+
PNDMScheduler,
|
| 14 |
+
LMSDiscreteScheduler,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
def exists(v):
    """Return True iff *v* is not None (None-check helper)."""
    return v is not None
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class CFGDenoiser(nn.Module):
    """
    Classifier-free guidance denoiser: wraps the diffusion model (its UNet)
    so one call produces a guided, noise-free prediction. Instead of only an
    empty-string unconditional branch, a real negative prompt is used; a list
    of two scales enables dual-guidance (reference CFG + sketch CFG).
    """

    def __init__(self, model, device):
        super().__init__()
        # Pick the wrapper matching the model's parameterization.
        wrapper_cls = CompVisDenoiser if model.parameterization == "eps" else CompVisVDenoiser
        self.model_wrap = wrapper_cls(model, device=device)

    @property
    def inner_model(self):
        return self.model_wrap

    def forward(
        self,
        x,
        sigma,
        cond: dict,
        cond_scale: Union[float, list[float]]
    ):
        """
        Simplified k-diffusion CFG step for sketch colorization.

        A scalar ``cond_scale`` <= 1 skips guidance entirely; > 1 runs the
        usual cond/uncond pair. A two-element list runs one conditional and
        two unconditional branches and averages the two guided estimates.
        """
        dual_cfg = isinstance(cond_scale, list)
        if not dual_cfg and cond_scale <= 1.:
            # No guidance needed — single plain denoising pass.
            return self.inner_model(x, sigma, cond=cond)

        repeats = 3 if dual_cfg else 2
        # Batch all branches through the model in one call, then split.
        branches = self.inner_model(
            torch.cat([x] * repeats), torch.cat([sigma] * repeats), cond=cond
        ).chunk(repeats)

        if not dual_cfg:
            x_cond, x_uncond = branches
            return x_uncond + (x_cond - x_uncond) * cond_scale

        x_cond, x_uncond_0, x_uncond_1 = branches
        return (x_uncond_0 + (x_cond - x_uncond_0) * cond_scale[0] +
                x_uncond_1 + (x_cond - x_uncond_1) * cond_scale[1]) * 0.5
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
# Directory holding one OmegaConf yaml of constructor kwargs per scheduler.
scheduler_config_path = "configs/scheduler_cfgs"
class DiffuserDenoiser:
    """
    Denoising loop built on Hugging Face diffusers schedulers.

    Supports no CFG, single-scale CFG, dual-scale CFG, and optional
    background inpainting via a binarised mask.
    """
    # Short config names -> diffusers scheduler classes. "dpm" and
    # "dpm_sde" share a class and differ only through their yaml configs.
    scheduler_types = {
        "ddim": DDIMScheduler,
        "dpm": DPMSolverMultistepScheduler,
        "dpm_sde": DPMSolverMultistepScheduler,
        "pndm": PNDMScheduler,
        "lms": LMSDiscreteScheduler
    }
    def __init__(self, scheduler_type, prediction_type, use_karras=False):
        # Accept names prefixed with "diffuser_" (UI naming convention).
        scheduler_type = scheduler_type.replace("diffuser_", "")
        assert scheduler_type in self.scheduler_types.keys(), "Selected scheduler is not implemented"
        scheduler = self.scheduler_types[scheduler_type]
        scheduler_config = OmegaConf.load(osp.abspath(osp.join(scheduler_config_path, scheduler_type + ".yaml")))
        # Only forward use_karras_sigmas to schedulers whose ctor accepts it.
        if "use_karras_sigmas" in set(inspect.signature(scheduler).parameters.keys()):
            scheduler_config.use_karras_sigmas = use_karras
        self.scheduler = scheduler(prediction_type=prediction_type, **scheduler_config)

    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(
            inspect.signature(self.scheduler.step).parameters.keys()
        )
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(
            inspect.signature(self.scheduler.step).parameters.keys()
        )
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def __call__(
        self,
        x,
        cond,
        cond_scale,
        unet,
        timesteps,
        generator: Optional[Union[torch.Generator, list[torch.Generator]]] = None,
        eta: float = 0.0,
        device: str = "cuda"
    ):
        """
        Run the full sampling loop.

        Args:
            x: initial latent noise (scaled by init_noise_sigma below).
            cond: conditioning dict; may carry "inpaint_bg", "mask",
                "threshold" entries for background inpainting.
            cond_scale: float or [float, float] guidance scale(s).
            unet: model exposing ``apply_model(x, t, cond=...)``.
            timesteps: number of inference steps (passed to set_timesteps).
            generator, eta: forwarded to schedulers that accept them.

        Returns:
            Final denoised latent tensor.
        """
        self.scheduler.set_timesteps(timesteps, device=device)
        timesteps = self.scheduler.timesteps
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # Keep the unscaled noise: reused as the noise source when re-noising
        # the inpainting background each step.
        x_start = x
        x = x * self.scheduler.init_noise_sigma
        inpaint_latents = cond.pop("inpaint_bg", None)

        if exists(inpaint_latents):
            mask = cond.get("mask", None)
            threshold = cond.pop("threshold", 0.5)
            inpaint_latents = inpaint_latents[0]
            assert exists(mask)
            mask = mask[0]
            # Binarise the soft mask: 1 = region being generated, 0 = keep background.
            mask = torch.where(mask > threshold, torch.ones_like(mask), torch.zeros_like(mask))

        for i, t in enumerate(tqdm(timesteps)):
            x_t = self.scheduler.scale_model_input(x, t)

            # repeats: 1 = no CFG, 2 = single-scale CFG, 3 = dual-scale CFG.
            if not isinstance(cond_scale, list):
                if cond_scale > 1.:
                    repeats = 2
                else:
                    repeats = 1
            else:
                repeats = 3

            x_in = torch.cat([x_t] * repeats)
            x_out = unet.apply_model(
                x_in,
                t[None].expand(x_in.shape[0]),
                cond=cond
            )

            if repeats == 1:
                pred = x_out

            elif repeats == 2:
                x_cond, x_uncond = x_out.chunk(2)
                pred = x_uncond + (x_cond - x_uncond) * cond_scale

            else:
                # Dual CFG: average two independently-guided predictions.
                x_cond, x_uncond_0, x_uncond_1 = x_out.chunk(3)
                pred = (x_uncond_0 + (x_cond - x_uncond_0) * cond_scale[0] +
                        x_uncond_1 + (x_cond - x_uncond_1) * cond_scale[1]) * 0.5

            x = self.scheduler.step(
                pred, t, x, **extra_step_kwargs, return_dict=False
            )[0]

            # Background inpainting: re-noise the fixed background latents to
            # the next timestep and paste them outside the mask.
            if exists(inpaint_latents) and exists(mask) and i < len(timesteps) - 1:
                noise_timestep = timesteps[i + 1]
                init_latents_proper = inpaint_latents
                init_latents_proper = self.scheduler.add_noise(
                    init_latents_proper, x_start, torch.tensor([noise_timestep])
                )
                x = (1 - mask) * init_latents_proper + mask * x

        return x
|
refnet/sampling/hook.py
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from refnet.modules.transformer import BasicTransformerBlock, SelfInjectedTransformerBlock
|
| 5 |
+
from refnet.util import checkpoint_wrapper
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
This implementation refers to Multi-ControlNet, thanks for the authors
|
| 9 |
+
Paper: Adding Conditional Control to Text-to-Image Diffusion Models
|
| 10 |
+
Link: https://github.com/Mikubill/sd-webui-controlnet
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
def exists(v):
    """True iff *v* holds a value, i.e. is not None."""
    return v is not None
|
| 15 |
+
|
| 16 |
+
def torch_dfs(model: nn.Module):
    """Pre-order depth-first list of *model* and every submodule inside it.

    Unlike ``nn.Module.modules()`` this does not deduplicate shared
    submodules: a module reused in several places appears once per use.
    """
    nodes = [model]
    for child in model.children():
        nodes.extend(torch_dfs(child))
    return nodes
|
| 21 |
+
|
| 22 |
+
class AutoMachine():
    # State flags for the attention hooks: "write" caches reference-pass
    # features into each block's bank, "read" consumes them while denoising.
    Read = "read"
    Write = "write"
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
"""
|
| 28 |
+
This class controls the attentions of reference unet and denoising unet
|
| 29 |
+
"""
|
| 30 |
+
class ReferenceAttentionControl:
    """
    Couples a writer (reference) unet with a reader (denoising) unet.

    The writer's transformer blocks cache their pre-norm self-attention
    input ("bank"); the reader's matching blocks concatenate that cached
    tensor into their own self-attention context, injecting reference
    features into the denoising pass.
    """

    def __init__(
        self,
        reader_module,
        writer_module,
        time_embed_ch = 0,
        only_decoder = True,
        *args,
        **kwargs
    ):
        # BUGFIX: writer_modules / reader_modules used to be class-level
        # mutable lists shared by every instance, so constructing a second
        # controller kept appending into the first one's lists (and
        # update()/restore() then walked stale modules). They are now
        # per-instance attributes.
        self.writer_modules = []
        self.reader_modules = []

        self.time_embed_ch = time_embed_ch
        self.trainable_layers = []
        self.only_decoder = only_decoder
        self.hooked = False

        self.register("read", reader_module)
        self.register("write", writer_module)

        if time_embed_ch > 0:
            self.insert_time_emb_proj(reader_module)

    def insert_time_emb_proj(self, unet):
        """Attach a trainable linear projection of the time embedding to
        every reader transformer block (added to `trainable_layers`)."""
        for module in torch_dfs(unet.output_blocks if self.only_decoder else unet):
            if isinstance(module, BasicTransformerBlock):
                module.time_proj = nn.Linear(self.time_embed_ch, module.dim)
                self.trainable_layers.append(module.time_proj)

    def register(self, mode, unet):
        """Swap the forward of each BasicTransformerBlock in *unet* for a
        caching ("write") or injecting ("read") variant."""

        @checkpoint_wrapper
        def transformer_forward_write(self, x, context=None, mask=None, emb=None, **kwargs):
            # Regular transformer pass that additionally stores the
            # pre-norm self-attention input for the reader to consume.
            x_in = self.norm1(x)
            x = self.attn1(x_in) + x

            if not self.disable_cross_attn:
                x = self.attn2(self.norm2(x), context) + x
            x = self.ff(self.norm3(x)) + x

            self.bank = x_in
            return x

        @checkpoint_wrapper
        def transformer_forward_read(self, x, context=None, mask=None, emb=None, **kwargs):
            if exists(self.bank):
                bank = self.bank
                if bank.shape[0] != x.shape[0]:
                    # Broadcast a single cached reference over the batch.
                    bank = bank.repeat(x.shape[0], 1, 1)
                if hasattr(self, "time_proj"):
                    # Condition the cached features on the current timestep.
                    bank = bank + self.time_proj(emb).unsqueeze(1)
                x_in = self.norm1(x)

                # Self-attention over [own tokens | reference tokens].
                x = self.attn1(
                    x = x_in,
                    context = torch.cat([x_in, bank], 1),
                    mask = mask,
                    scale_factor = self.scale_factor,
                    **kwargs
                ) + x

                x = self.attn2(
                    x = self.norm2(x),
                    context = context,
                    mask = mask,
                    scale = self.reference_scale,
                    scale_factor = self.scale_factor
                ) + x

                x = self.ff(self.norm3(x)) + x
            else:
                # No cached reference yet: fall back to the stock forward.
                x = self.original_forward(x, context, mask, emb)
            return x

        assert mode in ["write", "read"]

        if mode == "read":
            self.hooked = True
        for module in torch_dfs(unet.output_blocks if self.only_decoder else unet):
            if isinstance(module, BasicTransformerBlock):
                if mode == "write":
                    module.original_forward = module.forward
                    module.forward = transformer_forward_write.__get__(module, BasicTransformerBlock)
                    self.writer_modules.append(module)
                else:
                    # NOTE(review): the log line fires only for plain blocks;
                    # all blocks are hooked so reader/writer lists stay
                    # index-aligned for update() — confirm against
                    # SelfInjectedTransformerBlock's intended handling.
                    if not isinstance(module, SelfInjectedTransformerBlock):
                        print(f"Hooking transformer block {module.__class__.__name__} for read mode")
                    module.original_forward = module.forward
                    module.forward = transformer_forward_read.__get__(module, BasicTransformerBlock)
                    self.reader_modules.append(module)

    def update(self):
        """Copy each writer block's cached bank into the matching reader block."""
        for idx in range(len(self.writer_modules)):
            self.reader_modules[idx].bank = self.writer_modules[idx].bank

    def restore(self):
        """Unhook both unets and drop all cached banks."""
        for idx in range(len(self.writer_modules)):
            self.writer_modules[idx].forward = self.writer_modules[idx].original_forward
            self.reader_modules[idx].forward = self.reader_modules[idx].original_forward
            self.reader_modules[idx].bank = None
        self.hooked = False

    def clean(self):
        """Drop cached banks on both sides without unhooking forwards."""
        for idx in range(len(self.reader_modules)):
            self.reader_modules[idx].bank = None
        for idx in range(len(self.writer_modules)):
            self.writer_modules[idx].bank = None
        self.hooked = False

    def reader_restore(self):
        """Unhook only the reader unet and drop its banks."""
        for idx in range(len(self.reader_modules)):
            self.reader_modules[idx].forward = self.reader_modules[idx].original_forward
            self.reader_modules[idx].bank = None
        self.hooked = False

    def get_trainable_layers(self):
        """Return the time-embedding projections added by insert_time_emb_proj."""
        return self.trainable_layers
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
"""
|
| 149 |
+
This class is for self-injection inside the denoising unet
|
| 150 |
+
"""
|
| 151 |
+
class UnetHook:
    """
    "Reference-only" self-injection for the denoising unet: per step, a
    noised copy of the reference latent is pushed through the same unet to
    record self-attention contexts (Write), which the real denoising pass
    then concatenates into its own self-attention (Read).
    """
    def __init__(self):
        super().__init__()
        # Current phase of the write-then-read cycle for all hooked blocks.
        self.attention_auto_machine = AutoMachine.Read

    def enhance_reference(
        self,
        model,
        ldm,
        bs,
        s,
        r,
        style_cfg=0.5,
        control_cfg=0,
        gr_indice=None,
        injection=False,
        start_step=0,
    ):
        """
        Hook *model* (a wrapper around a diffusion unet) so every denoising
        call first replays a noised reference latent *r* (with sketch
        control *s*) to fill the attention banks.

        Args:
            ldm: owning latent-diffusion model (add_noise, control_encoder,
                num_timesteps).
            bs: batch size the encoded control is repeated to.
            style_cfg: style-fidelity blend in [0, 1] between reference-
                attended and self-only attention for the uncond entries.
            control_cfg: multiplier applied to the encoded control features.
            gr_indice: batch indices treated as unconditional entries.
            start_step: fraction of the schedule after which injection starts.
        """
        def forward(self, x, t, control, context, **kwargs):
            # Only inject once the normalized progress passed start_step.
            if 1 - t[0] / (ldm.num_timesteps - 1) >= outer.start_step:
                # Write
                outer.attention_auto_machine = AutoMachine.Write

                # Re-noise the reference latent to the current timestep and
                # run it through the unet purely for its attention side effects.
                rx = ldm.add_noise(outer.r.cpu(), torch.round(t.float()).long().cpu()).cuda().to(x.dtype)
                self.original_forward(rx, t, control=outer.s, context=context, **kwargs)

            # Read
            outer.attention_auto_machine = AutoMachine.Read
            return self.original_forward(x, t, control=control, context=context, **kwargs)

        def hacked_basic_transformer_inner_forward(self, x, context=None, mask=None, emb=None, **kwargs):
            x_norm1 = self.norm1(x)
            self_attn1 = None
            if self.disable_self_attn:
                # Do not use self-attention
                self_attn1 = self.attn1(x_norm1, context=context, **kwargs)

            else:
                # Use self-attention
                self_attention_context = x_norm1
                if outer.attention_auto_machine == AutoMachine.Write:
                    # Cache this block's pre-norm tokens for the read pass.
                    self.bank.append(self_attention_context.detach().clone())
                    self.style_cfgs.append(outer.current_style_fidelity)
                if outer.attention_auto_machine == AutoMachine.Read:
                    if len(self.bank) > 0:
                        style_cfg = sum(self.style_cfgs) / float(len(self.style_cfgs))
                        # Attend over [own tokens | cached reference tokens].
                        self_attn1_uc = self.attn1(
                            x_norm1,
                            context=torch.cat([self_attention_context] + self.bank, dim=1),
                            **kwargs
                        )
                        self_attn1_c = self_attn1_uc.clone()
                        if len(outer.current_uc_indices) > 0 and style_cfg > 1e-5:
                            # Style fidelity: unconditional batch entries attend
                            # only to themselves, not to the reference bank.
                            self_attn1_c[outer.current_uc_indices] = self.attn1(
                                x_norm1[outer.current_uc_indices],
                                context=self_attention_context[outer.current_uc_indices],
                                **kwargs
                            )
                        self_attn1 = style_cfg * self_attn1_c + (1.0 - style_cfg) * self_attn1_uc
                    # Banks are consumed once per step.
                    self.bank = []
                    self.style_cfgs = []
                if self_attn1 is None:
                    self_attn1 = self.attn1(x_norm1, context=self_attention_context)

            x = self_attn1.to(x.dtype) + x
            x = self.attn2(self.norm2(x), context, mask, self.reference_scale, self.scale_factor, **kwargs) + x
            x = self.ff(self.norm3(x)) + x
            return x

        # Encode and scale the sketch control once, repeated over the batch.
        self.s = [s.repeat(bs, 1, 1, 1) * control_cfg for s in ldm.control_encoder(s)]
        self.r = r
        # NOTE(review): `injection` is stored but not read anywhere in this
        # class — confirm whether it is consumed elsewhere.
        self.injection = injection
        self.start_step = start_step
        self.current_uc_indices = gr_indice
        self.current_style_fidelity = style_cfg

        outer = self
        model = model.diffusion_model
        model.original_forward = model.forward
        # TODO: change the class name to target
        model.forward = forward.__get__(model, model.__class__)
        all_modules = torch_dfs(model)

        for module in all_modules:
            if isinstance(module, BasicTransformerBlock):
                module._unet_hook_original_forward = module.forward
                module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
                module.bank = []
                module.style_cfgs = []


    def restore(self, model):
        """Undo enhance_reference: restore all forwards, drop cached state."""
        model = model.diffusion_model
        if hasattr(model, "original_forward"):
            model.forward = model.original_forward
            del model.original_forward

        all_modules = torch_dfs(model)
        for module in all_modules:
            if isinstance(module, BasicTransformerBlock):
                if hasattr(module, "_unet_hook_original_forward"):
                    module.forward = module._unet_hook_original_forward
                    del module._unet_hook_original_forward
                if hasattr(module, "bank"):
                    module.bank = None
                if hasattr(module, "style_cfgs"):
                    del module.style_cfgs
|
refnet/sampling/manipulation.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def compute_pwv(s: torch.Tensor, dscale: torch.Tensor, ratio=2, thresholds=(0.5, 0.55, 0.65, 0.95)):
    """
    Compute a piece-wise manipulation weight for every visual token.

    Per-token scores *s* (shape ``(b, n, 1)``) are min-max normalised per
    row and mapped through a piece-wise linear ramp: tokens at or below
    ``thresholds[0]`` are pushed in the opposite direction by ``ratio *
    dscale``, tokens above ``thresholds[3]`` receive the full ``dscale``,
    and the bands in between interpolate linearly.

    Args:
        s: similarity scores, shape ``(b, n, 1)``.
        dscale: scale delta to distribute across tokens.
        ratio: multiplier for the negative push on low-score tokens.
        thresholds: four increasing ramp break points.

    Returns:
        Per-token adjustment scales with the same shape as *s*.
    """
    # BUGFIX: the original `assert len(s.shape) == 3, len(thresholds) == 4`
    # used the second comparison as the assert *message*, so the threshold
    # count was never actually checked. Both conditions are now asserted.
    # (The default thresholds are also a tuple now, avoiding a mutable
    # default argument.)
    assert len(s.shape) == 3
    assert len(thresholds) == 4

    maxm = s.max(dim=1, keepdim=True).values
    minm = s.min(dim=1, keepdim=True).values
    # NOTE(review): if every score in a row is identical, d == 0 and the
    # normalisation divides by zero — confirm upstream rules this out.
    d = maxm - minm

    maxmin = (s - minm) / d

    adjust_scale = torch.where(maxmin <= thresholds[0],
                               -dscale * ratio,
                               -dscale + dscale * (maxmin - thresholds[0]) / (thresholds[1] - thresholds[0]))
    adjust_scale = torch.where(maxmin > thresholds[1],
                               0.5 * dscale * (maxmin - thresholds[1]) / (thresholds[2] - thresholds[1]),
                               adjust_scale)
    adjust_scale = torch.where(maxmin > thresholds[2],
                               0.5 * dscale + 0.5 * dscale * (maxmin - thresholds[2]) / (thresholds[3] - thresholds[2]),
                               adjust_scale)
    adjust_scale = torch.where(maxmin > thresholds[3], dscale, adjust_scale)
    return adjust_scale
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def local_manipulate_step(clip, v, t, target_scale, a=None, c=None, enhance=False, thresholds=[]):
    """
    Apply one token-wise (local) manipulation of visual tokens *v* toward
    target text embedding *t*, weighted per token by similarity to the
    control embedding *c* (via compute_pwv).

    The CLS token is split off, left untouched, and re-attached at the end.
    With an anchor *a*, tokens are moved along the anchor→target direction;
    otherwise they are moved along *t* directly.
    """
    # print(f"target:{t}, anchor:{a}")
    cls_token = v[:, 0].unsqueeze(1)
    v = v[:, 1:]

    # How strongly the image (CLS token) already expresses the target.
    cur_target_scale = clip.calculate_scale(cls_token, t)
    # control_scale = clip.calculate_scale(cls_token, c)
    # print(f"current global target scale: {cur_target_scale},",
    #       f" global control scale: {control_scale}")

    if a is not None and a != "none":
        a = [a] * v.shape[0]
        a = clip.encode_text(a)
        anchor_scale = clip.calculate_scale(cls_token, a)
        dscale = target_scale - cur_target_scale if not enhance else target_scale - anchor_scale
        # print(f"global anchor scale: {anchor_scale}")

        c_map = clip.calculate_scale(v, c)
        a_map = clip.calculate_scale(v, a)
        # NOTE(review): `c` is an encoded tensor here, so `c != "everything"`
        # compares a tensor with a string — confirm this evaluates truthy as
        # intended rather than raising on some torch versions.
        pwm = compute_pwv(c_map, dscale, thresholds=thresholds) if c != "everything" else dscale
        base = 1 if enhance else 0
        v = v + (pwm + base * a_map) * (t - a)
    else:
        dscale = target_scale - cur_target_scale
        c_map = clip.calculate_scale(v, c)
        pwm = compute_pwv(c_map, dscale, thresholds=thresholds) if c != "everything" else dscale
        v = v + pwm * t
    # Re-attach the untouched CLS token.
    v = torch.cat([cls_token, v], dim=1)
    return v
|
| 60 |
+
|
| 61 |
+
def local_manipulate(clip, v, targets, target_scales, anchors, controls, enhances=[], thresholds_list=[]):
    """
    Apply a sequence of local (token-wise) manipulations to visual tokens.

    v: visual tokens in shape (b, n, c)
    targets / controls: lists of prompt strings, jointly encoded then split
    anchors, target_scales, enhances, thresholds_list: per-step settings
    """
    # Encode controls and targets in a single batch, then split them back.
    encoded = clip.encode_text(controls + targets)
    controls, targets = encoded.chunk(2)
    step_params = zip(targets, anchors, controls, target_scales, enhances, thresholds_list)
    for tgt, anchor, ctrl, scale, enhance, thresholds in step_params:
        v = local_manipulate_step(clip, v, tgt, scale, anchor, ctrl, enhance, thresholds)
    return v
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def global_manipulate_step(clip, v, t, target_scale, a=None, enhance=False):
    """
    Apply one global manipulation of visual embedding *v* toward text
    embedding *t*.

    With an anchor prompt *a*: either subtract the anchor-aligned component
    (enhance) or translate along the anchor→target direction scaled by
    *target_scale*. Without an anchor: either add the target directly
    (enhance) or move v so its projection onto *t* reaches *target_scale*.
    """
    has_anchor = a is not None and a != "none"
    if has_anchor:
        anchor_emb = clip.encode_text([a] * v.shape[0])
        if enhance:
            # Remove the component of v aligned with the anchor concept.
            return v - clip.calculate_scale(v, anchor_emb) * anchor_emb
        # Shift v along the direction from anchor to target.
        return v + target_scale * (t - anchor_emb)

    if enhance:
        return v + target_scale * t
    # Move v so its projection onto t reaches target_scale.
    current = clip.calculate_scale(v, t)
    return v + (target_scale - current) * t
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def global_manipulate(clip, v, targets, target_scales, anchors, enhances):
    """Apply a sequence of global manipulations to visual embedding *v*.

    *targets* is a list of prompt strings encoded in one batch; the other
    arguments supply the per-step anchor, scale and enhance flag.
    """
    encoded_targets = clip.encode_text(targets)
    for tgt, anchor, scale, enhance in zip(encoded_targets, anchors, target_scales, enhances):
        v = global_manipulate_step(clip, v, tgt, scale, anchor, enhance)
    return v
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def assign_heatmap(s: torch.Tensor, threshold: float):
    """
    Binarise per-token scores into heat values.

    The scores (shape (b, n, 1)) are min-max normalised per row; entries
    below *threshold* map to 0.0 and the rest to 0.25.
    """
    hi = s.max(dim=1, keepdim=True).values
    lo = s.min(dim=1, keepdim=True).values
    normalised = (s - lo) / (hi - lo)
    hot = torch.ones_like(s) * 0.25
    return torch.where(normalised < threshold, torch.zeros_like(s), hot)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def get_heatmaps(model, reference, height, width, vis_c, ts0, ts1, ts2, ts3,
                 controls, targets, anchors, thresholds_list, target_scales, enhances):
    """
    Visualise, at four thresholds (ts0..ts3), which image regions respond
    to the text concept *vis_c* after applying the requested local
    manipulations to the reference image's CLIP tokens.

    Returns a list of four uint8 heatmap arrays of shape (height, width, 1).
    """
    model.low_vram_shift("cond")
    clip = model.cond_stage_model

    # Full token sequence: CLS token followed by grid tokens.
    v = clip.encode(reference, "full")
    if len(targets) > 0:
        controls, targets = clip.encode_text(controls + targets).chunk(2)
        inputs_iter = zip(controls, targets, anchors, target_scales, thresholds_list, enhances)
        for c, t, a, target_scale, thresholds, enhance in inputs_iter:
            # update image tokens
            v = local_manipulate_step(clip, v, t, target_scale, a, c, enhance, thresholds)
    # Grid tokens form a square patch grid (assumes a square token layout).
    token_length = v.shape[1] - 1
    grid_num = int(token_length ** 0.5)
    vis_c = clip.encode_text([vis_c])
    local_v = v[:, 1:]
    # Per-patch similarity to the visualised concept, upsampled to image size.
    scale = clip.calculate_scale(local_v, vis_c)
    scale = scale.permute(0, 2, 1).view(1, 1, grid_num, grid_num)
    scale = F.interpolate(scale, size=(height, width), mode="bicubic").squeeze(0).view(1, height * width)

    # calculate heatmaps
    heatmaps = []
    for threshold in [ts0, ts1, ts2, ts3]:
        heatmap = assign_heatmap(scale, threshold=threshold)
        heatmap = heatmap.view(1, height, width).permute(1, 2, 0).cpu().numpy()
        # 0.25 heat value -> 64/255 gray level after scaling.
        heatmap = (heatmap * 255.).astype(np.uint8)
        heatmaps.append(heatmap)
    return heatmaps
|
refnet/sampling/sampler.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import dataclasses
|
| 2 |
+
import torch
|
| 3 |
+
import k_diffusion
|
| 4 |
+
import inspect
|
| 5 |
+
|
| 6 |
+
from types import SimpleNamespace
|
| 7 |
+
from refnet.util import default
|
| 8 |
+
from .scheduler import schedulers, schedulers_map
|
| 9 |
+
from .denoiser import CFGDenoiser
|
| 10 |
+
|
| 11 |
+
# Default sampler options (mirrors the option names used by A1111-style UIs).
defaults = SimpleNamespace(
    eta_ddim=0.0,
    eta_ancestral=1.0,
    ddim_discretize="uniform",
    s_churn=0.0,
    s_tmin=0.0,
    s_noise=1.0,
    k_sched_type="Automatic",
    sigma_min=0.0,
    sigma_max=0.0,
    rho=0.0,
    eta_noise_seed_delta=0,
    always_discard_next_to_last_sigma=False,
)
|
| 25 |
+
|
| 26 |
+
@dataclasses.dataclass
class Sampler:
    """Registry entry describing one k-diffusion sampler.

    label: human-readable name shown in the UI.
    funcname: attribute name of the sampling function in
        ``k_diffusion.sampling`` (or a callable).
    aliases: alternative short names for lookup.
    options: per-sampler defaults (scheduler, second_order, ...).
    """
    label: str
    funcname: str
    # BUGFIX: was annotated `any` — the builtin function, not a type.
    # Every registry entry passes a list of strings.
    aliases: list
    options: dict
|
| 33 |
+
|
| 34 |
+
# Registry of supported k-diffusion samplers. `options` carries the default
# sigma scheduler plus flags consumed by get_sigmas / the sampling loop
# (discard_next_to_last_sigma, second_order, brownian_noise, ...).
samplers_k_diffusion = [
    Sampler('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {'scheduler': 'karras'}),
    Sampler('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
    Sampler('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde'], {'scheduler': 'exponential', "brownian_noise": True}),
    Sampler('DPM++ 2M SDE Heun', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun'], {'scheduler': 'exponential', "brownian_noise": True, "solver_type": "heun"}),
    Sampler('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
    Sampler('DPM++ 3M SDE', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde'], {'scheduler': 'exponential', 'discard_next_to_last_sigma': True, "brownian_noise": True}),
    Sampler('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}),
    Sampler('Euler', 'sample_euler', ['k_euler'], {}),
    Sampler('LMS', 'sample_lms', ['k_lms'], {}),
    Sampler('Heun', 'sample_heun', ['k_heun'], {"second_order": True}),
    Sampler('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "second_order": True}),
    Sampler('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
    Sampler('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
    Sampler('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True})
]
|
| 50 |
+
|
| 51 |
+
# Extra keyword arguments each sampling function accepts beyond the common
# signature; KDiffusionSampler.initialize() forwards only these.
sampler_extra_params = {
    'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
    'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
    'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
    'sample_dpm_fast': ['s_noise'],
    'sample_dpm_2_ancestral': ['s_noise'],
    'sample_dpmpp_2s_ancestral': ['s_noise'],
    'sample_dpmpp_sde': ['s_noise'],
    'sample_dpmpp_2m_sde': ['s_noise'],
    'sample_dpmpp_3m_sde': ['s_noise'],
}
|
| 62 |
+
|
| 63 |
+
def kdiffusion_sampler_list():
    """Return the display labels of all registered k-diffusion samplers."""
    return [sampler.label for sampler in samplers_k_diffusion]
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# Lookup tables: sampler display label -> Sampler config entry, and
# scheduler name -> sigma-schedule function.
k_diffusion_samplers_map = {x.label: x for x in samplers_k_diffusion}
k_diffusion_scheduler = {x.name: x.function for x in schedulers}
|
| 69 |
+
|
| 70 |
+
def exists(v):
    """Return whether *v* is a real value rather than None."""
    return v is not None
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class KDiffusionSampler:
|
| 75 |
+
    def __init__(self, sampler, scheduler, sd, device):
        """
        Build a k-diffusion sampler wrapper.

        Args:
            sampler: display label, looked up in k_diffusion_samplers_map.
            scheduler: sigma-scheduler name (or 'Automatic').
            sd: diffusion model wrapped by CFGDenoiser; may expose
                sigma_min/sigma_max overrides.
            device: torch device string for sigma schedules.
        """
        # k_diffusion_samplers_map[]
        self.config = k_diffusion_samplers_map[sampler]
        funcname = self.config.funcname

        # funcname may already be a callable instead of an attribute name.
        self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, funcname)
        self.scheduler_name = scheduler
        self.sd = CFGDenoiser(sd, device)
        self.model_wrap = self.sd.model_wrap
        self.device = device

        # Per-sampler noise parameters (may be overridden from `defaults`
        # in initialize()).
        self.s_min_uncond = None
        self.s_churn = 0.0
        self.s_tmin = 0.0
        self.s_tmax = float('inf')
        self.s_noise = 1.0

        # Which `defaults` field supplies eta for this sampler family.
        self.eta_option_field = 'eta_ancestral'
        self.eta_infotext_field = 'Eta'
        self.eta_default = 1.0
        self.eta = None

        self.extra_params = []

        # Clamp the sigma range to model-specified bounds when provided.
        if exists(sd.sigma_max) and exists(sd.sigma_min):
            self.model_wrap.sigmas[-1] = sd.sigma_max
            self.model_wrap.sigmas[0] = sd.sigma_min
|
| 102 |
+
|
| 103 |
+
    def initialize(self):
        """
        Assemble the keyword arguments forwarded to the sampling function,
        pulling eta and the churn/noise knobs from the module-level
        `defaults` and keeping the corresponding self.* fields in sync.

        Returns:
            dict of extra kwargs accepted by self.func.
        """
        self.eta = getattr(defaults, self.eta_option_field, 0.0)

        extra_params_kwargs = {}
        # Forward only the parameters this sampling function actually accepts.
        for param_name in self.extra_params:
            if param_name in inspect.signature(self.func).parameters:
                extra_params_kwargs[param_name] = getattr(self, param_name)

        if 'eta' in inspect.signature(self.func).parameters:
            extra_params_kwargs['eta'] = self.eta

        if len(self.extra_params) > 0:
            s_churn = getattr(defaults, 's_churn', self.s_churn)
            s_tmin = getattr(defaults, 's_tmin', self.s_tmin)
            s_tmax = getattr(defaults, 's_tmax', self.s_tmax) or self.s_tmax # 0 = inf
            s_noise = getattr(defaults, 's_noise', self.s_noise)

            # Only override values that actually changed from the current state.
            if 's_churn' in extra_params_kwargs and s_churn != self.s_churn:
                extra_params_kwargs['s_churn'] = s_churn
                self.s_churn = s_churn
            if 's_tmin' in extra_params_kwargs and s_tmin != self.s_tmin:
                extra_params_kwargs['s_tmin'] = s_tmin
                self.s_tmin = s_tmin
            if 's_tmax' in extra_params_kwargs and s_tmax != self.s_tmax:
                extra_params_kwargs['s_tmax'] = s_tmax
                self.s_tmax = s_tmax
            if 's_noise' in extra_params_kwargs and s_noise != self.s_noise:
                extra_params_kwargs['s_noise'] = s_noise
                self.s_noise = s_noise

        return extra_params_kwargs
|
| 134 |
+
|
| 135 |
+
def create_noise_sampler(self, x, sigmas, seed):
    """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
    from k_diffusion.sampling import BrownianTreeNoiseSampler
    # Smallest positive sigma and overall largest sigma bound the tree.
    positive_sigmas = sigmas[sigmas > 0]
    return BrownianTreeNoiseSampler(x, positive_sigmas.min(), sigmas.max(), seed=seed)
|
| 140 |
+
|
| 141 |
+
def get_sigmas(self, steps, sigmas_min=None, sigmas_max=None):
    """Build the noise-level (sigma) schedule for `steps` sampling steps.

    Args:
        steps: number of denoising steps requested.
        sigmas_min / sigmas_max: optional overrides for the schedule range;
            when None, the wrapped model's own limits are used.

    Returns:
        1-D tensor of sigmas, ending at 0.
    """
    discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)

    # Generate one extra sigma so the next-to-last one can be dropped below.
    steps += 1 if discard_next_to_last_sigma else 0

    if self.scheduler_name == 'Automatic':
        # Resolve 'Automatic' to the sampler config's preferred scheduler.
        # Guard against a missing config, consistent with the check above
        # (the original dereferenced self.config unconditionally here).
        self.scheduler_name = self.config.options.get('scheduler', None) if self.config is not None else None

    scheduler = schedulers_map.get(self.scheduler_name)
    sigma_min = default(sigmas_min, self.model_wrap.sigma_min)
    sigma_max = default(sigmas_max, self.model_wrap.sigma_max)

    if scheduler is None or scheduler.function is None:
        # Unknown scheduler or 'automatic' placeholder: use the model's table.
        sigmas = self.model_wrap.get_sigmas(steps)
    else:
        sigmas_kwargs = {'sigma_min': sigma_min, 'sigma_max': sigma_max}

        if scheduler.need_inner_model:
            sigmas_kwargs['inner_model'] = self.model_wrap

        sigmas = scheduler.function(n=steps, **sigmas_kwargs, device=self.device)

    if discard_next_to_last_sigma:
        # Drop the next-to-last sigma but keep the terminal one (usually 0).
        sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])

    return sigmas
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def __call__(self, x, sigmas, sampler_extra_args, seed, deterministic, steps=None):
    """Run the k-diffusion sampler function on latent `x` over the sigma schedule.

    Args:
        x: initial latent noise (scaled by the largest sigma before sampling).
        sigmas: 1-D sigma schedule (largest first).
        sampler_extra_args: dict forwarded to the model via `extra_args`.
        seed: RNG seed used for the Brownian noise sampler.
        deterministic: when True, create a seeded noise sampler for
            batch-size-independent results.
        steps: step count forwarded to samplers that accept `n`.
    """
    # Scale the unit-variance noise up to the first (largest) sigma.
    x = x * sigmas[0]

    extra_params_kwargs = self.initialize()
    parameters = inspect.signature(self.func).parameters

    if 'n' in parameters:
        extra_params_kwargs['n'] = steps

    if 'sigma_min' in parameters:
        extra_params_kwargs['sigma_min'] = sigmas[sigmas > 0].min()
        extra_params_kwargs['sigma_max'] = sigmas.max()

    if 'sigmas' in parameters:
        extra_params_kwargs['sigmas'] = sigmas

    # Guard: config may be None (get_sigmas checks this explicitly; the
    # original dereferenced self.config unconditionally here).
    options = self.config.options if self.config is not None else {}

    if options.get('brownian_noise', False):
        # Deterministic runs get an explicitly seeded Brownian noise sampler.
        noise_sampler = self.create_noise_sampler(x, sigmas, seed) if deterministic else None
        extra_params_kwargs['noise_sampler'] = noise_sampler

    if options.get('solver_type', None) == 'heun':
        extra_params_kwargs['solver_type'] = 'heun'

    return self.func(self.sd, x, extra_args=sampler_extra_args, disable=False, **extra_params_kwargs)
|
refnet/sampling/scheduler.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import dataclasses
import typing

import torch
import k_diffusion
|
| 4 |
+
|
| 5 |
+
@dataclasses.dataclass
class Scheduler:
    """Descriptor for one noise-schedule option.

    Fixes the original annotations: `function` was annotated with the
    builtin `any` (a function, not a type) and `aliases: list = None`
    declared a non-list default.
    """
    name: str                       # internal identifier, e.g. 'karras'
    label: str                      # human-readable label shown in the UI
    function: typing.Any            # callable producing sigmas; None for the 'automatic' placeholder

    default_rho: float = -1         # rho for Karras-style schedules; -1 = unset
    need_inner_model: bool = False  # whether `function` needs the wrapped model passed in
    aliases: typing.Optional[list] = None  # alternative labels, if any
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def uniform(n, sigma_min, sigma_max, inner_model, device):
    """Uniform schedule: delegate to the wrapped model's own sigma table.

    `sigma_min`, `sigma_max` and `device` are accepted only so the signature
    matches the other schedule functions; they are not used here.
    """
    sigmas = inner_model.get_sigmas(n)
    return sigmas
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def sgm_uniform(n, sigma_min, sigma_max, inner_model, device):
    """SGM-style schedule: n sigmas uniformly spaced in t-space, terminated by 0."""
    t_max = inner_model.sigma_to_t(torch.tensor(sigma_max))
    t_min = inner_model.sigma_to_t(torch.tensor(sigma_min))
    # n + 1 evenly spaced timesteps with the final endpoint dropped,
    # each mapped back to sigma space.
    timesteps = torch.linspace(t_max, t_min, n + 1)[:-1]
    sigs = [inner_model.t_to_sigma(t) for t in timesteps]
    sigs.append(0.0)
    return torch.FloatTensor(sigs).to(device)
|
| 29 |
+
|
| 30 |
+
# Registry of available noise schedulers. `function=None` marks the
# 'Automatic' placeholder that is resolved later from the sampler config.
schedulers = [
    Scheduler('automatic', 'Automatic', None),
    Scheduler('uniform', 'Uniform', uniform, need_inner_model=True),
    Scheduler('karras', 'Karras', k_diffusion.sampling.get_sigmas_karras, default_rho=7.0),
    Scheduler('exponential', 'Exponential', k_diffusion.sampling.get_sigmas_exponential),
    Scheduler('polyexponential', 'Polyexponential', k_diffusion.sampling.get_sigmas_polyexponential, default_rho=1.0),
    Scheduler('sgm_uniform', 'SGM Uniform', sgm_uniform, need_inner_model=True, aliases=["SGMUniform"]),
]
|
| 38 |
+
|
| 39 |
+
def get_noise_schedulers():
    """Return the display labels of every registered scheduler."""
    labels = [entry.label for entry in schedulers]
    return labels
|
| 41 |
+
|
| 42 |
+
# Lookup table keyed by both internal name and display label.
schedulers_map = {**{x.name: x for x in schedulers}, **{x.label: x for x in schedulers}}
|
refnet/sampling/tps_transformation.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
Calculate warped image using control point manipulation on a thin plate (TPS)
|
| 3 |
+
Based on Herve Lombaert's 2006 web article
|
| 4 |
+
"Manual Registration with Thin Plates"
|
| 5 |
+
(https://profs.etsmtl.ca/hlombaert/thinplates/)
|
| 6 |
+
|
| 7 |
+
Implementation by Yucheol Jung <ycjung@postech.ac.kr>
|
| 8 |
+
'''
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
import PIL.Image as Image
|
| 13 |
+
import torchvision.transforms as tf
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def tps_warp(images, num_points=10, perturbation_strength=10, random=True, pts_before=None, pts_after=None):
    """Warp a batch of images with a thin-plate-spline deformation.

    When `random` is True, `num_points` control points are sampled uniformly
    over the image and displaced by Gaussian noise scaled by
    `perturbation_strength`. Otherwise the caller must supply `pts_before`
    and `pts_after` as NxTx2 tensors.
    """
    if random:
        b, c, h, w = images.shape
        device, dtype = images.device, images.dtype
        # NOTE(review): the random points are scaled by [h, w] here, while
        # _tps_warp normalizes coordinate 0 by W and coordinate 1 by H —
        # the axes look swapped for non-square images; confirm intended order.
        # Also note torch.Tensor([[[h, w]]]) is float32 regardless of `dtype`.
        pts_before = torch.rand([b, num_points, 2], dtype=dtype, device=device) * torch.Tensor([[[h, w]]]).to(device)
        pts_after = pts_before + torch.randn([b, num_points, 2], dtype=dtype, device=device) * perturbation_strength
    return _tps_warp(images, pts_before, pts_after)
|
| 23 |
+
|
| 24 |
+
def _tps_warp(im, pts_before, pts_after, normalize=True):
    '''
    Deforms image according to movement of pts_before and pts_after

    Args)
        im          torch.Tensor object of size NxCxHxW
        pts_before  torch.Tensor object of size NxTx2 (T is # control pts)
        pts_after   torch.Tensor object of size NxTx2 (T is # control pts)
    '''
    # check input requirements
    assert (4 == im.dim())
    assert (3 == pts_after.dim())
    assert (3 == pts_before.dim())
    N = im.size()[0]  # batch size
    assert (N == pts_after.size()[0] and N == pts_before.size()[0])
    assert (2 == pts_after.size()[2] and 2 == pts_before.size()[2])
    T = pts_after.size()[1]  # number of control points
    assert (T == pts_before.size()[1])
    H = im.size()[2]
    W = im.size()[3]

    if normalize:
        # Map pixel coordinates into grid_sample's [-1, 1] range.
        # Clone first so the caller's tensors are not mutated in place.
        pts_after = pts_after.clone()
        pts_after[:, :, 0] /= 0.5 * W
        pts_after[:, :, 1] /= 0.5 * H
        pts_after -= 1
        pts_before = pts_before.clone()
        pts_before[:, :, 0] /= 0.5 * W
        pts_before[:, :, 1] /= 0.5 * H
        pts_before -= 1

    def construct_P():
        '''
        Consturcts matrix P of size NxTx3 where
            P[n,i,0] := 1
            P[n,i,1:] := pts_after[n]
        '''
        # Create matrix P with same configuration as 'pts_after'
        P = pts_after.new_zeros((N, T, 3))
        P[:, :, 0] = 1
        P[:, :, 1:] = pts_after

        return P

    def calc_U(pt1, pt2):
        '''
        Calculate distance U between pt1 and pt2

            U(r) := r**2 * log(r)
        where
            r := |pt1 - pt2|_2

        Args)
            pt1  torch.Tensor object, last dim is always 2
            pt2  torch.Tensor object, last dim is always 2
        '''
        assert (2 == pt1.size()[-1])
        assert (2 == pt2.size()[-1])

        diff = pt1 - pt2
        sq_diff = diff ** 2
        sq_diff_sum = sq_diff.sum(-1)
        r = sq_diff_sum.sqrt()

        # Adds 1e-6 for numerical stability
        return (r ** 2) * torch.log(r + 1e-6)

    def construct_K():
        '''
        Consturcts matrix K of size NxTxT where
            K[n,i,j] := U(|pts_after[n,i] - pts_after[n,j]|_2)
        '''

        # Assuming the number of control points are small enough,
        # We just use for-loop for easy-to-read code

        # Create matrix K with same configuration as 'pts_after'
        K = pts_after.new_zeros((N, T, T))
        for i in range(T):
            for j in range(T):
                K[:, i, j] = calc_U(pts_after[:, i, :], pts_after[:, j, :])

        return K

    def construct_L():
        '''
        Consturcts matrix L of size Nx(T+3)x(T+3) where
            L[n] = [[ K[n]    P[n] ]]
                   [[ P[n]^T  0    ]]
        '''
        P = construct_P()
        K = construct_K()

        # Create matrix L with same configuration as 'K'
        L = K.new_zeros((N, T + 3, T + 3))

        # Fill L matrix
        L[:, :T, :T] = K
        L[:, :T, T:(T + 3)] = P
        L[:, T:(T + 3), :T] = P.transpose(1, 2)

        return L

    def construct_uv_grid():
        '''
        Returns H x W x 2 tensor uv with UV coordinate as its elements
            uv[:,:,0] is H x W grid of x values
            uv[:,:,1] is H x W grid of y values
        '''
        # torch.arange excludes the endpoint, so this yields exactly W (resp. H)
        # samples in [-1, 1).
        u_range = torch.arange(
            start=-1.0, end=1.0, step=2.0 / W, device=im.device)
        assert (W == u_range.size()[0])
        u = u_range.new_zeros((H, W))
        u[:] = u_range

        v_range = torch.arange(
            start=-1.0, end=1.0, step=2.0 / H, device=im.device)
        assert (H == v_range.size()[0])
        vt = v_range.new_zeros((W, H))
        vt[:] = v_range
        v = vt.transpose(0, 1)

        return torch.stack([u, v], dim=2)

    L = construct_L()
    VT = pts_before.new_zeros((N, T + 3, 2))
    # Use delta x and delta y as known heights of the surface
    VT[:, :T, :] = pts_before - pts_after

    # Solve Lx = VT
    # x is of shape (N, T+3, 2)
    # x[:,:,0] represents surface parameters for dx surface
    #   (dx values as surface height (z))
    # x[:,:,1] represents surface parameters for dy surface
    #   (dy values as surface height (z))
    x = torch.linalg.solve(L, VT)

    uv = construct_uv_grid()
    uv_batch = uv.repeat((N, 1, 1, 1))

    def calc_dxdy():
        '''
        Calculate surface height for each uv coordinate

        Returns NxHxWx2 tensor
        '''

        # control points of size NxTxHxWx2
        cp = uv.new_zeros((H, W, N, T, 2))
        cp[:, :, :] = pts_after
        cp = cp.permute([2, 3, 0, 1, 4])

        U = calc_U(uv, cp)  # U value matrix of size NxTxHxW
        # Split the solved parameters into the non-affine weights w and the
        # affine part a of the TPS surface.
        w, a = x[:, :T, :], x[:, T:, :]  # w is of size NxTx2, a is of size Nx3x2
        w_x, w_y = w[:, :, 0], w[:, :, 1]  # NxT each
        a_x, a_y = a[:, :, 0], a[:, :, 1]  # Nx3 each
        dx = (
            a_x[:, 0].repeat((H, W, 1)).permute(2, 0, 1) +
            torch.einsum('nhwd,nd->nhw', uv_batch, a_x[:, 1:]) +
            torch.einsum('nthw,nt->nhw', U, w_x))  # dx values of NxHxW
        dy = (
            a_y[:, 0].repeat((H, W, 1)).permute(2, 0, 1) +
            torch.einsum('nhwd,nd->nhw', uv_batch, a_y[:, 1:]) +
            torch.einsum('nthw,nt->nhw', U, w_y))  # dy values of NxHxW

        return torch.stack([dx, dy], dim=3)

    dxdy = calc_dxdy()
    # Displace the identity sampling grid by the TPS surface heights.
    flow_field = uv + dxdy

    # NOTE(review): grid_sample is called without align_corners; recent
    # PyTorch defaults to align_corners=False — confirm that is intended.
    return F.grid_sample(im, flow_field.to(im.dtype))
|
| 195 |
+
|
| 196 |
+
if __name__ == '__main__':
    # Quick visual smoke test: random-warp a sample image and display it.
    # Requires the hard-coded sample image path to exist.
    num_points = 10
    perturbation_strength = 10
    img = tf.ToTensor()(Image.open("../../miniset/origin/109281263.jpg").convert("RGB")).unsqueeze(0)
    # img = tf.ToTensor()(Image.open("../../miniset/origin/109281263.jpg").convert("RGB").resize((224, 224))).unsqueeze(0)
    img = tps_warp(img, num_points= num_points, perturbation_strength = perturbation_strength).squeeze(0)
    img = tf.ToPILImage()(img)
    img.show()
|