diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..1339be116951e45a93f4e67ceda40cd8d78dee63
--- /dev/null
+++ b/app.py
@@ -0,0 +1,227 @@
+import gradio as gr
+import argparse
+
+from refnet.sampling import get_noise_schedulers, get_sampler_list
+from functools import partial
+from backend import *
+
# Paper / weights / source links for each model variant.
# NOTE(review): currently unused — the header Markdown in `init_interface`
# is empty, so these badges are never rendered; restore or remove.
links = {
    "base": "https://arxiv.org/abs/2401.01456",
    "v1": "https://openaccess.thecvf.com/content/WACV2025/html/Yan_ColorizeDiffusion_Improving_Reference-Based_Sketch_Colorization_with_Latent_Diffusion_Model_WACV_2025_paper.html",
    "v1.5": "https://arxiv.org/abs/2502.19937v1",
    "v2": "https://arxiv.org/abs/2504.06895",
    "xl": "https://arxiv.org/abs/2601.04883",
    "weights": "https://huggingface.co/tellurion/colorizer/tree/main",
    "github": "https://github.com/tellurion-kanata/colorizeDiffusion",
}
+
def app_options(argv=None):
    """Parse the web app's command-line options.

    argv: optional explicit argument list (defaults to ``sys.argv[1:]``).
        Accepting it as a parameter keeps the parser testable without
        touching process state; existing callers (``app_options()``) are
        unaffected.

    Returns an ``argparse.Namespace`` with server_name, server_port, share
    and enable_text_manipulation.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--server_name", '-addr', type=str, default="0.0.0.0")
    parser.add_argument("--server_port", '-port', type=int, default=7860)
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--enable_text_manipulation", '-manipulate', action="store_true")
    return parser.parse_args(argv)
+
+
def init_interface(opt, *args, **kwargs) -> None:
    """Assemble the Gradio Blocks UI, wire its callbacks, and launch the server.

    Blocks until the server exits. Only ``opt.server_name``, ``opt.server_port``,
    ``opt.share`` and ``opt.enable_text_manipulation`` are read.
    """
    sampler_list = get_sampler_list()
    scheduler_list = get_noise_schedulers()

    # Shared constructor for the three image-upload widgets.
    img_block = partial(gr.Image, type="pil", height=300, interactive=True, show_label=True, format="png")
    with gr.Blocks(
        title = "Colorize Diffusion",
        css_paths = "backend/style.css",
        theme = gr.themes.Ocean(),
        elem_id = "main-interface",
        analytics_enabled = False,
        fill_width = True
    ) as block:
        # NOTE(review): the header Markdown body is empty, so nothing is
        # rendered here (and the module-level `links` dict goes unused).
        with gr.Row(elem_id="header-row", equal_height=True, variant="panel"):
            gr.Markdown(f"""
""")

        with gr.Row(elem_id="content-row", equal_height=False, variant="panel"):
            with gr.Column():
                # Text-manipulation widgets; hidden unless launched with -manipulate.
                with gr.Row(visible=opt.enable_text_manipulation):
                    target = gr.Textbox(label="Target prompt", value="", scale=2)
                    anchor = gr.Textbox(label="Anchor prompt", value="", scale=2)
                    control = gr.Textbox(label="Control prompt", value="", scale=2)
                with gr.Row(visible=opt.enable_text_manipulation):
                    target_scale = gr.Slider(label="Target scale", value=0.0, minimum=0, maximum=15.0, step=0.25, scale=2)
                    ts0 = gr.Slider(label="Threshold 0", value=0.5, minimum=0, maximum=1.0, step=0.01)
                    ts1 = gr.Slider(label="Threshold 1", value=0.55, minimum=0, maximum=1.0, step=0.01)
                    ts2 = gr.Slider(label="Threshold 2", value=0.65, minimum=0, maximum=1.0, step=0.01)
                    ts3 = gr.Slider(label="Threshold 3", value=0.95, minimum=0, maximum=1.0, step=0.01)
                with gr.Row(visible=opt.enable_text_manipulation):
                    enhance = gr.Checkbox(label="Enhance manipulation", value=False)
                    add_prompt = gr.Button(value="Add")
                    clear_prompt = gr.Button(value="Clear")
                    vis_button = gr.Button(value="Visualize")
                text_prompt = gr.Textbox(label="Final prompt", value="", lines=3, visible=opt.enable_text_manipulation)

                # Image inputs.
                with gr.Row():
                    sketch_img = img_block(label="Sketch")
                    reference_img = img_block(label="Reference")
                    background_img = img_block(label="Background")

                # Flags forwarded to `inference` but not exposed as widgets.
                style_enhance = gr.State(False)
                fg_enhance = gr.State(False)
                with gr.Row():
                    bg_enhance = gr.Checkbox(label="Low-level injection", value=False)
                    injection = gr.Checkbox(label="Attention injection", value=False)
                    autofit_size = gr.Checkbox(label="Autofit size", value=False)
                with gr.Row():
                    gs_r = gr.Slider(label="Reference guidance scale", minimum=1, maximum=15.0, value=4.0, step=0.5)
                    strength = gr.Slider(label="Reference strength", minimum=0, maximum=1, value=1, step=0.05)
                    fg_strength = gr.Slider(label="Foreground strength", minimum=0, maximum=1, value=1, step=0.05)
                    bg_strength = gr.Slider(label="Background strength", minimum=0, maximum=1, value=1, step=0.05)
                with gr.Row():
                    gs_s = gr.Slider(label="Sketch guidance scale", minimum=1, maximum=5.0, value=1.0, step=0.1)
                    ctl_scale = gr.Slider(label="Sketch strength", minimum=0, maximum=3, value=1, step=0.05)
                    mask_scale = gr.Slider(label="Background factor", minimum=0, maximum=2, value=1, step=0.05)
                    merge_scale = gr.Slider(label="Merging scale", minimum=0, maximum=1, value=0, step=0.05)
                with gr.Row():
                    bs = gr.Slider(label="Batch size", minimum=1, maximum=4, value=1, step=1, scale=1)
                    width = gr.Slider(label="Width", minimum=512, maximum=1536, value=1024, step=32, scale=2)
                with gr.Row():
                    step = gr.Slider(label="Step", minimum=1, maximum=100, value=20, step=1, scale=1)
                    height = gr.Slider(label="Height", minimum=512, maximum=1536, value=1024, step=32, scale=2)

                seed = gr.Slider(label="Seed", minimum=-1, maximum=MAXM_INT32, step=1, value=-1)
                with gr.Accordion("Advanced Settings", open=False):
                    with gr.Row():
                        crop = gr.Checkbox(label="Crop result", value=False, scale=1)
                        remove_fg = gr.Checkbox(label="Remove foreground in background input", value=False, scale=2)
                        rmbg = gr.Checkbox(label="Remove background in result", value=False, scale=2)
                        latent_inpaint = gr.Checkbox(label="Latent copy BG input", value=False, scale=2)
                    with gr.Row():
                        injection_control_scale = gr.Slider(label="Injection fidelity (sketch)", minimum=0.0,
                                                            maximum=2.0, value=0, step=0.05)
                        injection_fidelity = gr.Slider(label="Injection fidelity (reference)", minimum=0.0,
                                                       maximum=1.0, value=0.5, step=0.05)
                        injection_start_step = gr.Slider(label="Injection start step", minimum=0.0, maximum=1.0,
                                                         value=0, step=0.05)

                with gr.Row():
                    reuse_seed = gr.Button(value="Reuse Seed")
                    random_seed = gr.Button(value="Random Seed")

            with gr.Column():
                result_gallery = gr.Gallery(
                    label='Output', show_label=False, elem_id="gallery", preview=True, type="pil", format="png"
                )
                run_button = gr.Button("Generate", variant="primary", size="lg")
                with gr.Row():
                    mask_ts = gr.Slider(label="Reference mask threshold", minimum=0., maximum=1., value=0.5, step=0.01)
                    mask_ss = gr.Slider(label="Sketch mask threshold", minimum=0., maximum=1., value=0.05, step=0.01)
                    pad_scale = gr.Slider(label="Reference padding scale", minimum=1, maximum=2, value=1, step=0.05)

                with gr.Row():
                    sd_model = gr.Dropdown(choices=get_available_models(), label="Models",
                                           value=get_available_models()[0])
                    extractor_model = gr.Dropdown(choices=line_extractor_list,
                                                  label="Line extractor", value=default_line_extractor)
                    mask_model = gr.Dropdown(choices=mask_extractor_list, label="Reference mask extractor",
                                             value=default_mask_extractor)
                with gr.Row():
                    sampler = gr.Dropdown(choices=sampler_list, value="DPM++ 3M SDE", label="Sampler")
                    scheduler = gr.Dropdown(choices=scheduler_list, value=scheduler_list[0], label="Noise scheduler")
                    preprocessor = gr.Dropdown(choices=["none", "extract", "invert", "invert-webui"],
                                               label="Sketch preprocessor", value="invert")

                with gr.Row():
                    deterministic = gr.Checkbox(label="Deterministic batch seed", value=False)
                    save_memory = gr.Checkbox(label="Save memory", value=True)

            # Hidden states for unused advanced controls
            fg_disentangle_scale = gr.State(1.0)
            start_step = gr.State(0.0)
            end_step = gr.State(1.0)
            no_start_step = gr.State(-0.05)
            no_end_step = gr.State(-0.05)
            return_inter = gr.State(False)
            accurate = gr.State(False)
            enc_scale = gr.State(1.0)
            middle_scale = gr.State(1.0)
            low_scale = gr.State(1.0)
            ctl_scale_1 = gr.State(1.0)
            ctl_scale_2 = gr.State(1.0)
            ctl_scale_3 = gr.State(1.0)
            ctl_scale_4 = gr.State(1.0)

        # Callback wiring.
        add_prompt.click(fn=apppend_prompt,
                         inputs=[target, anchor, control, target_scale, enhance, ts0, ts1, ts2, ts3, text_prompt],
                         outputs=[target, anchor, control, target_scale, enhance, ts0, ts1, ts2, ts3, text_prompt])
        clear_prompt.click(fn=clear_prompts, outputs=[text_prompt])

        reuse_seed.click(fn=get_last_seed, outputs=[seed])
        random_seed.click(fn=reset_random_seed, outputs=[seed])

        extractor_model.input(fn=switch_extractor, inputs=[extractor_model])
        sd_model.input(fn=load_model, inputs=[sd_model])
        mask_model.input(fn=switch_mask_extractor, inputs=[mask_model])

        # Order must match the positional parameters of `backend.inference`.
        ips = [style_enhance, bg_enhance, fg_enhance, fg_disentangle_scale,
               bs, sketch_img, reference_img, background_img, mask_ts, mask_ss, gs_r, gs_s, ctl_scale,
               ctl_scale_1, ctl_scale_2, ctl_scale_3, ctl_scale_4, fg_strength, bg_strength, merge_scale,
               mask_scale, height, width, seed, save_memory, step, injection, autofit_size,
               remove_fg, rmbg, latent_inpaint, injection_control_scale, injection_fidelity, injection_start_step,
               crop, pad_scale, start_step, end_step, no_start_step, no_end_step, return_inter, sampler, scheduler,
               preprocessor, deterministic, text_prompt, target, anchor, control, target_scale, ts0, ts1, ts2, ts3,
               enhance, accurate, enc_scale, middle_scale, low_scale, strength]

        run_button.click(
            fn = inference,
            inputs = ips,
            outputs = [result_gallery],
        )

        vis_button.click(
            fn = visualize,
            inputs = [reference_img, text_prompt, control, ts0, ts1, ts2, ts3],
            outputs = [result_gallery],
        )

    # Blocking call: starts the web server.
    block.launch(
        server_name = opt.server_name,
        share = opt.share,
        server_port = opt.server_port,
    )
+
+
if __name__ == '__main__':
    opt = app_options()
    try:
        # Warm start: load the default checkpoint and both preprocessors
        # before the UI comes up, so the first generation pays no load cost.
        models = get_available_models()
        load_model(models[0])
        switch_extractor(default_line_extractor)
        switch_mask_extractor(default_mask_extractor)
        # init_interface blocks until the server exits (it returns None).
        interface = init_interface(opt)
    except Exception as e:
        print(f"Error initializing interface: {e}")
        raise
diff --git a/backend/__init__.py b/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b265edbc6413b6c5d7871b5d7dcfb15067c2be1a
--- /dev/null
+++ b/backend/__init__.py
@@ -0,0 +1,16 @@
+from .appfunc import *
+
+
# Public surface: names re-exported from backend.appfunc plus the extractor
# defaults declared below.
__all__ = [
    'switch_extractor', 'switch_mask_extractor',
    'get_available_models', 'load_model', 'inference', 'reset_random_seed', 'get_last_seed',
    'apppend_prompt', 'clear_prompts', 'visualize',
    'default_line_extractor', 'default_mask_extractor', 'MAXM_INT32',
    'mask_extractor_list', 'line_extractor_list',
]


# Preprocessors pre-selected in the UI dropdowns at startup.
default_line_extractor = "lineart_keras"
default_mask_extractor = "rmbg-v2"
# Choices offered by the "Reference mask extractor" / "Line extractor" dropdowns.
mask_extractor_list = ["none", "ISNet", "rmbg-v2", "BiRefNet", "BiRefNet_HR"]
line_extractor_list = ["lineart", "lineart_denoise", "lineart_keras", "lineart_sk"]
diff --git a/backend/appfunc.py b/backend/appfunc.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba75de23089a17cce6a9edd15f1dcf520895a963
--- /dev/null
+++ b/backend/appfunc.py
@@ -0,0 +1,298 @@
+import os
+import random
+import traceback
+import gradio as gr
+import os.path as osp
+
+from huggingface_hub import hf_hub_download
+
+from omegaconf import OmegaConf
+from refnet.util import instantiate_from_config
+from preprocessor import create_model
+from .functool import *
+
# Mutable module-level state shared by the UI callbacks below.
model = None                 # currently loaded diffusion model (None until load_model)
model_type = ""              # type key of the loaded model ("sdxl" / "xlv2")
current_checkpoint = ""      # filename of the loaded checkpoint
global_seed = None           # seed used by the last generation (for "Reuse Seed")

# Sketch-side mask extractor; parked on CPU, moved to CUDA only while in use.
smask_extractor = create_model("ISNet-sketch").cpu()

# Upper bound for user-facing seeds (seed slider / random seed draw).
# Bug fix: the previous literal 429496729 was neither INT32_MAX nor
# UINT32_MAX (a dropped digit); use the real signed 32-bit maximum.
MAXM_INT32 = 2 ** 31 - 1

# HuggingFace model repository
HF_REPO_ID = "tellurion/colorizer"
MODEL_CACHE_DIR = "models"

# Model registry: filename -> model_type
MODEL_REGISTRY = {
    "sdxl.safetensors": "sdxl",
    "xlv2.safetensors": "xlv2",
}

# Recognised model-type prefixes for checkpoints not in the registry.
model_types = ["sdxl", "xlv2"]

'''
    Gradio UI functions
'''
+
+
def get_available_models():
    """List the checkpoint filenames known to the model registry."""
    return [name for name in MODEL_REGISTRY]
+
+
def download_model(filename):
    """Fetch *filename* from the HuggingFace Hub unless it is already cached.

    Returns the local filesystem path of the checkpoint.
    """
    os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
    local_path = osp.join(MODEL_CACHE_DIR, filename)
    if osp.exists(local_path):
        return local_path

    # Bug fix: the progress messages contained a literal "(unknown)"
    # placeholder inside f-strings; report the actual file being fetched.
    print(f"Downloading {filename} from {HF_REPO_ID}...")
    gr.Info(f"Downloading {filename}...")
    path = hf_hub_download(
        repo_id=HF_REPO_ID,
        filename=filename,
        local_dir=MODEL_CACHE_DIR,
    )
    print(f"Downloaded to {path}")
    return path
+
+
def switch_extractor(type):
    """Load the line-extractor model named *type* and install it globally.

    Failures are reported to the UI instead of being raised.
    """
    global line_extractor
    try:
        line_extractor = create_model(type)
        gr.Info(f"Switched to {type} extractor")
    except Exception as e:
        print(f"Error info: {e}")
        # Bug fix: traceback.print_exc() writes to stderr and returns None,
        # so wrapping it in print() just printed "None".
        traceback.print_exc()
        gr.Info(f"Failed in loading {type} extractor")
+
+
def switch_mask_extractor(type):
    """Load the reference-mask extractor named *type* and install it globally.

    Failures are reported to the UI instead of being raised.
    """
    global mask_extractor
    try:
        mask_extractor = create_model(type)
        gr.Info(f"Switched to {type} extractor")
    except Exception as e:
        print(f"Error info: {e}")
        # Bug fix: traceback.print_exc() writes to stderr and returns None,
        # so wrapping it in print() just printed "None".
        traceback.print_exc()
        gr.Info(f"Failed in loading {type} extractor")
+
+
def apppend_prompt(target, anchor, control, scale, enhance, ts0, ts1, ts2, ts3, prompt):
    """Serialize one text-manipulation entry onto *prompt*.

    Blank target/anchor/control fields are stored as "none". Returns reset
    values for every input widget followed by the updated prompt text.
    (Function name typo is kept: it is part of the exported interface.)
    """
    target, anchor, control = (
        field.strip() or "none" for field in (target, anchor, control)
    )
    entry = (
        f"\n[target] {target}; [anchor] {anchor}; [control] {control}; "
        f"[scale] {scale}; [enhanced] {enhance}; [ts0] {ts0}; [ts1] {ts1}; "
        f"[ts2] {ts2}; [ts3] {ts3}"
    )
    return "", "", "", 0.0, False, 0.5, 0.55, 0.65, 0.95, (prompt + entry).strip()
+
+
def clear_prompts():
    """Reset the accumulated manipulation prompt textbox (Clear button)."""
    return ""
+
+
def load_model(ckpt_name):
    """Load checkpoint *ckpt_name*, rebuilding the network when its type changes.

    The model type is resolved from MODEL_REGISTRY first, then by filename
    prefix. Weights are downloaded from the HuggingFace Hub on demand.
    Failures are reported to the UI instead of being raised.
    """
    global model, model_type, current_checkpoint
    config_root = "configs/inference"

    try:
        # Determine model type from registry or filename prefix
        new_model_type = MODEL_REGISTRY.get(ckpt_name, "")
        if not new_model_type:
            for key in model_types:
                if ckpt_name.startswith(key):
                    new_model_type = key
                    break

        # Bug fix: the previous guard tested `"model" in globals()`, which is
        # always true because `model` is initialised to None at module scope.
        # Rebuild when the type changed or no model object exists yet.
        if model_type != new_model_type or model is None:
            if exists(model):
                del model
            config_path = osp.join(config_root, f"{new_model_type}.yaml")
            new_model = instantiate_from_config(OmegaConf.load(config_path).model).cpu().eval()
            print(f"Switched to {new_model_type} model, loading weights from [{ckpt_name}]...")
            model = new_model

        # Download model from HF Hub
        local_path = download_model(ckpt_name)

        # "eps"-parameterised checkpoints are tagged in their filename;
        # everything else uses v-parameterisation.
        model.parameterization = "eps" if "eps" in ckpt_name else "v"
        model.init_from_ckpt(local_path, logging=True)
        model.switch_to_fp16()

        model_type = new_model_type
        current_checkpoint = ckpt_name
        print(f"Loaded model from [{ckpt_name}], model_type [{model_type}].")
        gr.Info("Loaded model successfully.")

    except Exception as e:
        print(f"Error type: {e}")
        # Bug fix: traceback.print_exc() returns None; don't print() it.
        traceback.print_exc()
        gr.Info("Failed in loading model.")
+
+
def get_last_seed():
    """Return the seed used by the last generation, or -1 before any run.

    Bug fix: the previous ``global_seed or -1`` mapped a legitimate seed of
    0 to -1; test for None explicitly instead of truthiness.
    """
    return -1 if global_seed is None else global_seed
+
+
def reset_random_seed():
    """Return -1, the sentinel that makes `inference` draw a fresh random seed."""
    return -1
+
+
def visualize(reference, text, *args):
    """Render manipulation heatmaps for *reference* with the loaded model.

    *args forwards (control, ts0, ts1, ts2, ts3) from the UI widgets.
    """
    return visualize_heatmaps(model, reference, parse_prompts(text), *args)
+
+
def set_cas_scales(accurate, cas_args):
    """Assemble the cross-attention scale-strength settings.

    When *accurate* is False, the first four entries of *cas_args*
    (encoder/middle/low scales plus a global strength) are folded into three
    per-level scales; otherwise the remaining entries are passed through as
    an explicit scale list.
    """
    enc, mid, low, strength = cas_args[:4]
    if accurate:
        return {"level_control": False, "scales": list(cas_args[4:])}
    return {
        "level_control": True,
        "scales": {
            "encoder": enc * strength,
            "middle": mid * strength,
            "low": low * strength,
        },
    }
+
+
@torch.no_grad()
def inference(
    style_enhance, bg_enhance, fg_enhance, fg_disentangle_scale,
    bs, input_s, input_r, input_bg, mask_ts, mask_ss, gs_r, gs_s, ctl_scale,
    ctl_scale_1, ctl_scale_2, ctl_scale_3, ctl_scale_4,
    fg_strength, bg_strength, merge_scale, mask_scale, height, width, seed, low_vram, step,
    injection, autofit_size, remove_fg, rmbg, latent_inpaint, infid_x, infid_r, injstep, crop, pad_scale,
    start_step, end_step, no_start_step, no_end_step, return_inter, sampler, scheduler, preprocess,
    deterministic, text, target, anchor, control, target_scale, ts0, ts1, ts2, ts3, enhance, accurate,
    *args
):
    """Run one reference-based colorization pass and return gallery images.

    Parameter order mirrors the `ips` list built in `app.init_interface`;
    *args carries the cross-attention scale settings consumed by
    `set_cas_scales`. Returns the image list produced by `postprocess`.
    """
    global global_seed, line_extractor, mask_extractor
    # Remember the effective seed so the "Reuse Seed" button can recall it.
    global_seed = seed if seed > -1 else random.randint(0, MAXM_INT32)
    torch.manual_seed(global_seed)

    # Auto-fit size based on sketch dimensions
    if autofit_size and exists(input_s):
        sketch_w, sketch_h = input_s.size
        aspect_ratio = sketch_w / sketch_h
        # Keep roughly a 1024x1024 pixel budget at the sketch aspect ratio,
        # rounded to multiples of 32 and clamped to [768, 1536].
        target_area = 1024 * 1024
        new_h = int((target_area / aspect_ratio) ** 0.5)
        new_w = int(new_h * aspect_ratio)
        height = ((new_h + 16) // 32) * 32
        width = ((new_w + 16) // 32) * 32
        height = max(768, min(1536, height))
        width = max(768, min(1536, width))
        gr.Info(f"Auto-fitted size: {width}x{height}")

    smask, rmask, bgmask = None, None, None
    manipulation_params = parse_prompts(text, target, anchor, control, target_scale, ts0, ts1, ts2, ts3, enhance)
    inputs = preprocessing_inputs(
        sketch = input_s,
        reference = input_r,
        background = input_bg,
        preprocess = preprocess,
        hook = injection,
        resolution = (height, width),
        extractor = line_extractor,
        pad_scale = pad_scale,
    )
    sketch, reference, background, original_shape, inject_xr, inject_xs, white_sketch = inputs

    cond = {"reference": reference, "sketch": sketch, "background": background}
    mask_guided = bg_enhance or fg_enhance

    # Foreground/background guidance needs masks extracted from the sketch
    # and from the reference (or background) image.
    if exists(white_sketch) and exists(reference) and mask_guided:
        mask_extractor.cuda()
        smask_extractor.cuda()
        smask = smask_extractor.proceed(
            x=white_sketch, pil_x=input_s, th=height, tw=width, threshold=mask_ss, crop=False
        )

        if exists(background) and remove_fg:
            # White out the background's own foreground before conditioning.
            bgmask = mask_extractor.proceed(x=background, pil_x=input_bg, threshold=mask_ts, dilate=True)
            filtered_background = torch.where(bgmask < mask_ts, background, torch.ones_like(background))
            cond.update({"background": filtered_background, "rmask": bgmask})
        else:
            rmask = mask_extractor.proceed(x=reference, pil_x=input_r, threshold=mask_ts, dilate=True)
            cond.update({"rmask": rmask})
            # Binarised copy kept locally for the postprocess previews.
            rmask = torch.where(rmask > 0.5, torch.ones_like(rmask), torch.zeros_like(rmask))
        cond.update({"smask": smask})
        smask_extractor.cpu()
        mask_extractor.cpu()

    scale_strength = set_cas_scales(accurate, args)
    # Per-level sketch control scales, all modulated by the global slider.
    ctl_scales = [ctl_scale_1, ctl_scale_2, ctl_scale_3, ctl_scale_4]
    ctl_scales = [t * ctl_scale for t in ctl_scales]

    results = model.generate(
        # Colorization mode
        style_enhance = style_enhance,
        bg_enhance = bg_enhance,
        fg_enhance = fg_enhance,
        fg_disentangle_scale = fg_disentangle_scale,
        latent_inpaint = latent_inpaint,

        # Conditional inputs
        cond = cond,
        ctl_scale = ctl_scales,
        merge_scale = merge_scale,
        mask_scale = mask_scale,
        mask_thresh = mask_ts,
        mask_thresh_sketch = mask_ss,

        # Sampling settings
        bs = bs,
        gs = [gs_r, gs_s],
        sampler = sampler,
        scheduler = scheduler,
        start_step = start_step,
        end_step = end_step,
        no_start_step = no_start_step,
        no_end_step = no_end_step,
        strength = scale_strength,
        fg_strength = fg_strength,
        bg_strength = bg_strength,
        seed = global_seed,
        deterministic = deterministic,
        height = height,
        width = width,
        step = step,

        # Injection settings
        injection = injection,
        injection_cfg = infid_r,
        injection_control = infid_x,
        injection_start_step = injstep,
        hook_xr = inject_xr,
        hook_xs = inject_xs,

        # Additional settings
        low_vram = low_vram,
        return_intermediate = return_inter,
        manipulation_params = manipulation_params,
    )

    if rmbg:
        # Use the sketch silhouette to white out the generated background.
        mask_extractor.cuda()
        mask = smask_extractor.proceed(x=-sketch, threshold=mask_ss).repeat(results.shape[0], 1, 1, 1)
        results = torch.where(mask >= mask_ss, results, torch.ones_like(results))
        mask_extractor.cpu()

    results = postprocess(results, sketch, reference, background, crop, original_shape,
                          mask_guided, smask, rmask, bgmask, mask_ts, mask_ss)
    torch.cuda.empty_cache()
    gr.Info("Generation completed.")
    return results
diff --git a/backend/functool.py b/backend/functool.py
new file mode 100644
index 0000000000000000000000000000000000000000..45db522c469d10dba1dab2459cc4e022e8a7263c
--- /dev/null
+++ b/backend/functool.py
@@ -0,0 +1,276 @@
+import cv2
+import numpy as np
+import PIL.Image as Image
+
+import torch
+import torch.nn as nn
+import torchvision.transforms as transforms
+
+from functools import partial
+
# Upper bound on either side of a reference image fed to heatmap
# visualization ("maxium" sic — name kept for compatibility).
maxium_resolution = 4096
# Tokens per side of the attention-token grid: sqrt(256) = 16.
token_length = int(256 ** 0.5)
+
def exists(v):
    """Return True when *v* is an actual value, i.e. not None."""
    return v is not None
+
+resize = partial(transforms.Resize, interpolation=transforms.InterpolationMode.BICUBIC, antialias=True)
+
def resize_image(img, new_size, w, h):
    """Resize *img* so that the longer of the given (w, h) sides becomes *new_size*."""
    if h >= w:
        return resize((new_size, int(w / h * new_size)))(img)
    return resize((int(h / w * new_size), new_size))(img)
+
def pad_image(image: torch.Tensor, h, w):
    """Center *image* on an (h, w) canvas filled with -1 (white in [-1, 1] space).

    Returns the padded tensor together with (left, top, width, height) of the
    original placement, so the region can be cropped back out later.
    """
    b, c, ih, iw = image.shape
    left, top = (w - iw) // 2, (h - ih) // 2
    canvas = -torch.ones([b, c, h, w], device=image.device)
    canvas[:, :, top:top + ih, left:left + iw] = image
    return canvas, (left, top, iw, ih)
+
+
def pad_image_with_margin(image: Image, scale):
    """Widen *image* horizontally by *scale*, splitting white margins evenly."""
    w, h = image.size
    new_w = int(w * scale)
    canvas = Image.new('RGB', (new_w, h), (255, 255, 255))
    canvas.paste(image, ((new_w - w) // 2, 0))
    return canvas
+
+
def crop_image_from_square(square_image, original_dim):
    """Undo `pad_image`: crop the original placement box back out."""
    left, top, width, height = original_dim
    box = (left, top, left + width, top + height)
    return square_image.crop(box)
+
+
def to_tensor(x, inverse=False):
    """PIL image / ndarray -> normalized [-1, 1] CUDA tensor of shape (1, C, H, W).

    inverse=True flips the sign (used to turn black-line sketches white-line).
    """
    t = transforms.ToTensor()(x).unsqueeze(0)
    t = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(t).cuda()
    return -t if inverse else t
+
def to_numpy(x, denormalize=True):
    """Tensor -> uint8 numpy image(s).

    denormalize=True: treat *x* as a (B, C, H, W) batch in [-1, 1] and return
    a (B, H, W, C) uint8 batch. Otherwise treat *x* as a mask in [0, 1] and
    return the first channel of the first item as a 2-D uint8 array.
    """
    if not denormalize:
        return (x.clamp(0, 1) * 255)[0][0].cpu().numpy().astype(np.uint8)
    scaled = (x.clamp(-1, 1) + 1.) * 127.5
    return scaled.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
+
def lineart_standard(x: Image.Image):
    # Difference-of-Gaussian style line extraction (selected by the
    # "invert-webui" sketch preprocessor): blur, take the darkest per-pixel
    # channel difference, then normalize by the median line intensity.
    x = np.array(x).astype(np.float32)
    g = cv2.GaussianBlur(x, (0, 0), 6.0)
    intensity = np.min(g - x, axis=2).clip(0, 255)
    # NOTE(review): np.median over an empty selection (no pixels > 8) yields
    # NaN — presumably real sketches always contain line pixels; confirm.
    intensity /= max(16, np.median(intensity[intensity > 8]))
    intensity *= 127
    intensity = np.repeat(np.expand_dims(intensity, 2), 3, axis=2)
    result = to_tensor(intensity.clip(0, 255).astype(np.uint8))
    return result
+
def preprocess_sketch(sketch, resolution, preprocess="none", extractor=None, new=False):
    """Convert a PIL sketch into a padded tensor at the target resolution.

    preprocess: "none" keeps the image as-is, "invert" flips the sign
    (black-on-white -> white-line), "invert-webui" runs the DoG-style
    extractor, any other value runs the neural line *extractor*.
    Returns (sketch tensor, original placement box, white-line variant).
    """
    w, h = sketch.size
    th, tw = resolution
    # Uniform scale so the sketch fits inside (th, tw) without distortion.
    r = min(th/h, tw/w)

    if preprocess == "none":
        sketch = to_tensor(sketch)
    elif preprocess == "invert":
        sketch = to_tensor(sketch, inverse=True)
    elif preprocess == "invert-webui":
        sketch = lineart_standard(sketch)
    else:
        sketch = extractor.proceed(resize((768, 768))(sketch)).repeat(1, 3, 1, 1)

    sketch, original_shape = pad_image(resize((int(h*r), int(w*r)))(sketch), th, tw)
    if new:
        # "new" models expect sketches in [0, 1] rather than [-1, 1].
        sketch = ((sketch + 1) / 2.).clamp(0, 1)
        white_sketch = 1 - sketch
    else:
        white_sketch = -sketch
    return sketch, original_shape, white_sketch
+
+
@torch.no_grad()
def preprocessing_inputs(
    sketch: Image.Image,
    reference: Image.Image,
    background: Image.Image,
    preprocess: str,
    hook: bool,
    resolution: tuple[int, int],
    extractor: nn.Module,
    pad_scale: float = 1.,
    new = False
):
    """Turn the three PIL inputs into model-ready tensors.

    Returns (sketch, reference, background, original_shape, inject_xr,
    inject_xs, white_sketch); missing optional inputs come back as None, a
    missing sketch as a blank canvas covering the full resolution.
    """
    extractor = extractor.cuda()
    h, w = resolution
    if exists(sketch):
        sketch, original_shape, white_sketch = preprocess_sketch(sketch, resolution, preprocess, extractor, new)
    else:
        # No sketch provided: blank canvas in the model's value convention.
        sketch = torch.zeros([1, 3, h, w], device="cuda") if new else -torch.ones([1, 3, h, w], device="cuda")
        white_sketch = None
        original_shape = (0, 0, h, w)

    inject_xs = None
    if hook:
        # Attention injection requires the raw reference at target size.
        assert exists(reference) and exists(extractor)
        maxm = max(h, w)
        # inject_xs = resize((h, w))(extractor.proceed(resize((maxm, maxm))(reference)).repeat(1, 3, 1, 1))
        inject_xr = to_tensor(resize((h, w))(reference))
    else:
        inject_xr = None
    extractor = extractor.cpu()

    if exists(reference):
        if pad_scale > 1.:
            reference = pad_image_with_margin(reference, pad_scale)
        reference = to_tensor(reference)

    if exists(background):
        if pad_scale > 1.:
            background = pad_image_with_margin(background, pad_scale)
        background = to_tensor(background)

    return sketch, reference, background, original_shape, inject_xr, inject_xs, white_sketch
+
def postprocess(results, sketch, reference, background, crop, original_shape,
                mask_guided, smask, rmask, bgmask, mask_ts, mask_ss, new=False):
    """Convert tensors to display images and append input/mask previews.

    The returned list starts with the generated image(s), then the sketch,
    then (when present) the reference, background and mask visualizations.
    """
    results = to_numpy(results)
    # Sketch is denormalized unless the "new" ([0, 1]) convention is in use.
    sketch = to_numpy(sketch, not new)[0]

    results_list = []
    for result in results:
        result = Image.fromarray(result)
        if crop:
            result = crop_image_from_square(result, original_shape)
        results_list.append(result)

    results_list.append(sketch)

    if exists(reference):
        reference = to_numpy(reference)[0]
        results_list.append(reference)
        # if vis_crossattn:
        #     results_list += visualize_attention_map(reference, results_list[0], vh, vw)

    if exists(background):
        background = to_numpy(background)[0]
        results_list.append(background)

        if exists(bgmask):
            background = Image.fromarray(background)
            # Background with its foreground whited out, and the inverse view.
            results_list.append(Image.composite(
                background,
                Image.new("RGB", background.size, (255, 255, 255)),
                Image.fromarray(to_numpy(bgmask, denormalize=False), mode="L")
            ))
            results_list.append(Image.composite(
                Image.new("RGB", background.size, (255, 255, 255)),
                background,
                Image.fromarray(to_numpy(bgmask, denormalize=False), mode="L")
            ))

    if mask_guided:
        # Zero out sub-threshold sketch-mask values before preview.
        smask[smask < mask_ss] = 0
        results_list.append(Image.fromarray(to_numpy(smask, denormalize=False), mode="L"))

        if exists(rmask):
            reference = Image.fromarray(reference)
            rmask[rmask < mask_ts] = 0
            results_list.append(Image.fromarray(to_numpy(rmask, denormalize=False), mode="L"))
            results_list.append(Image.composite(
                reference,
                Image.new("RGB", reference.size, (255, 255, 255)),
                Image.fromarray(to_numpy(rmask, denormalize=False), mode="L")
            ))
            results_list.append(Image.composite(
                Image.new("RGB", reference.size, (255, 255, 255)),
                reference,
                Image.fromarray(to_numpy(rmask, denormalize=False), mode="L")
            ))

    return results_list
+
+
def parse_prompts(
    prompts: str,
    target: str = None,
    anchor: str = None,
    control: str = None,
    target_scale: float = None,
    ts0: float = None,
    ts1: float = None,
    ts2: float = None,
    ts3: float = None,
    enhance: bool = None
):
    """Parse the serialized manipulation prompt built by `apppend_prompt`.

    Each line of *prompts* has the form
    "[target] t; [anchor] a; [control] c; [scale] s; [enhanced] e; [ts0] ...".
    The optional trailing arguments append one extra in-progress entry taken
    straight from the UI widgets. Returns a dict of parallel lists.
    (Type hints corrected: the text fields are strings, the scale a float.)
    """
    targets, anchors, controls = [], [], []
    scales, enhances, thresholds_list = [], [], []

    separators = ["; [anchor] ", "; [control] ", "; [scale]", "; [enhanced]",
                  "; [ts0]", "; [ts1]", "; [ts2]", "; [ts3]"]
    if prompts:
        for line in prompts.split('\n'):
            line = line.replace("[target] ", "")
            # Collapse every field separator to one delimiter, then split.
            for sep in separators:
                line = line.replace(sep, "||||")

            fields = line.split("||||")
            targets.append(fields[0])
            anchors.append(fields[1])
            controls.append(fields[2])
            scales.append(float(fields[3]))
            # Bug fix: bool("False") is True — compare against the serialized
            # literal written by `apppend_prompt` instead of casting.
            enhances.append(fields[4].strip() == "True")
            thresholds_list.append([float(fields[5]), float(fields[6]),
                                    float(fields[7]), float(fields[8])])

    # Append the in-progress entry when a non-empty target is supplied.
    if target:
        targets.append(target)
        anchors.append(anchor)
        controls.append(control)
        scales.append(target_scale)
        enhances.append(enhance)
        thresholds_list.append([ts0, ts1, ts2, ts3])

    return {
        "targets": targets,
        "anchors": anchors,
        "controls": controls,
        "target_scales": scales,
        "enhances": enhances,
        "thresholds_list": thresholds_list,
    }
+
+
from refnet.sampling.manipulation import get_heatmaps
def visualize_heatmaps(model, reference, manipulation_params, control, ts0, ts1, ts2, ts3):
    """Overlay the manipulation attention heatmap on *reference*.

    Returns a single-element list for the result gallery, or [] when no
    reference image is provided.
    """
    if reference is None:
        return []

    # Bound the longer side by maxium_resolution, preserving aspect ratio.
    size = reference.size
    if size[0] > maxium_resolution or size[1] > maxium_resolution:
        if size[0] > size[1]:
            size = (maxium_resolution, int(float(maxium_resolution) / size[0] * size[1]))
        else:
            size = (int(float(maxium_resolution) / size[1] * size[0]), maxium_resolution)
        reference = reference.resize(size, Image.BICUBIC)

    reference = np.array(reference)
    scale_maps = get_heatmaps(model, to_tensor(reference), size[1], size[0],
                              control, ts0, ts1, ts2, ts3, **manipulation_params)

    # Sum the four threshold bands into one map and blend it over the image.
    scale_map = scale_maps[0] + scale_maps[1] + scale_maps[2] + scale_maps[3]
    heatmap = cv2.cvtColor(cv2.applyColorMap(scale_map, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
    result = cv2.addWeighted(reference, 0.3, heatmap, 0.7, 0)
    # Draw the 16x16 attention-token grid on top of the blend.
    hu = size[1] // token_length
    wu = size[0] // token_length
    for i in range(16):
        result[i * hu, :] = (0, 0, 0)
    for i in range(16):
        result[:, i * wu] = (0, 0, 0)

    return [result]
diff --git a/backend/style.css b/backend/style.css
new file mode 100644
index 0000000000000000000000000000000000000000..8b9e29512a265bf4ca7e34ec4bc982a59d1421ff
--- /dev/null
+++ b/backend/style.css
@@ -0,0 +1,181 @@
/* Theme for the Colorize Diffusion Gradio app (loaded via css_paths). */
:root {
    --primary-color: #9b59b6;
    --primary-light: #d6c6e1;
    --secondary-color: #2ecc71;
    --text-color: #333333;
    --background-color: #f9f9f9;
    --card-bg: #ffffff;
    --border-radius: 10px;
    --shadow-sm: 0 2px 5px rgba(0, 0, 0, 0.05);
    --shadow-md: 0 5px 15px rgba(0, 0, 0, 0.07);
    --shadow-lg: 0 10px 25px rgba(0, 0, 0, 0.1);
    --gradient: linear-gradient(135deg, var(--primary-color), var(--secondary-color));
    --input-border: #e0e0e0;
    --input-bg: #ffffff;
    --font-weight-normal: 500;
    --font-weight-bold: 700;
}

/* Base styles */
body, html {
    margin: 0;
    padding: 0;
    font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
    font-weight: var(--font-weight-normal);
    background-color: var(--background-color);
    color: var(--text-color);
    width: 100vw;
    overflow-x: hidden;
}

* {
    box-sizing: border-box;
}

/* Force full width layout */
#main-interface,
.gradio-app,
.gradio-container {
    width: 100vw !important;
    max-width: 100vw !important;
    margin: 0 !important;
    padding: 0 !important;
    box-shadow: none !important;
    border: none !important;
    overflow-x: hidden !important;
}

/* Header styling */
#header-row {
    background: white;
    padding: 15px 20px;
    margin-bottom: 20px;
    box-shadow: var(--shadow-sm);
    border-bottom: 1px solid rgba(0,0,0,0.05);
}

.header-container {
    width: 100%;
    display: flex;
    flex-direction: column;
    align-items: center;
    padding: 10px 0;
}

.app-header {
    display: flex;
    align-items: center;
    gap: 12px;
    margin-bottom: 15px;
}

.app-header .emoji {
    font-size: 36px;
}

/* Fix for Colorize Diffusion title visibility */
.gradio-markdown h1,
.gradio-markdown h2,
#header-row h1,
#header-row h2,
.title-text,
.app-header .title-text {
    display: inline-block !important;
    visibility: visible !important;
    opacity: 1 !important;
    position: relative !important;
    color: var(--primary-color) !important;
    font-size: 32px !important;
    font-weight: 800 !important;
}

/* Badge links under the header */
.paper-links-icons {
    display: flex;
    flex-wrap: wrap;
    justify-content: center;
    gap: 8px;
    margin-top: 5px;
}

.paper-links-icons a {
    transition: transform 0.2s ease;
    opacity: 0.9;
}

.paper-links-icons a:hover {
    transform: translateY(-3px);
    opacity: 1;
}

/* Content layout */
#content-row {
    padding: 0 20px 20px 20px;
    max-width: 100%;
    margin: 0 auto;
}

/* Apply bold font to all text elements for better readability */
p, span, label, button, input, textarea, select, .gradio-button, .gradio-checkbox, .gradio-dropdown, .gradio-textbox {
    font-weight: var(--font-weight-normal);
}

/* Make headings bolder */
h1, h2, h3, h4, h5, h6 {
    font-weight: var(--font-weight-bold);
}

/* Improved font styling for Gradio UI elements */
.gradio-container,
.gradio-container *,
.gradio-app,
.gradio-app * {
    font-family: 'Roboto', 'Segoe UI', system-ui, -apple-system, sans-serif !important;
    font-weight: 500 !important;
}

/* Style for labels and slider labels */
.gradio-container label,
.gradio-slider label,
.gradio-checkbox label,
.gradio-radio label,
.gradio-dropdown label,
.gradio-textbox label,
.gradio-number label,
.gradio-button,
.gradio-checkbox span,
.gradio-radio span {
    font-family: 'Roboto', 'Segoe UI', system-ui, -apple-system, sans-serif !important;
    font-weight: 600 !important;
    letter-spacing: 0.01em;
}

/* Style for buttons */
button,
.gradio-button {
    font-family: 'Roboto', 'Segoe UI', system-ui, -apple-system, sans-serif !important;
    font-weight: 600 !important;
}

/* Style for input values */
input,
textarea,
select,
.gradio-textbox textarea,
.gradio-number input {
    font-family: 'Roboto', 'Segoe UI', system-ui, -apple-system, sans-serif !important;
    font-weight: 500 !important;
}

/* Better styling for drop areas */
.upload-box,
[data-testid="image"] {
    font-family: 'Roboto', 'Segoe UI', system-ui, -apple-system, sans-serif !important;
    font-weight: 500 !important;
}

/* Additional styling for values in sliders and numbers */
.wrap .wrap .wrap span {
    font-family: 'Roboto', 'Segoe UI', system-ui, -apple-system, sans-serif !important;
    font-weight: 600 !important;
}
diff --git a/configs/inference/sdxl.yaml b/configs/inference/sdxl.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..bc9d373f7428ab9c8567083f3c4ae215e70e5cc4
--- /dev/null
+++ b/configs/inference/sdxl.yaml
@@ -0,0 +1,88 @@
+model:
+ base_learning_rate: 1.0e-6
+ target: refnet.models.colorizerXL.InferenceWrapper
+ params:
+ linear_start: 0.00085
+ linear_end: 0.0120
+ timesteps: 1000
+ image_size: 128
+ channels: 4
+ scale_factor: 0.13025
+ logits_embed: false
+
+ unet_config:
+ target: refnet.modules.unet.DualCondUNetXL
+ params:
+ use_checkpoint: True
+ in_channels: 4
+ out_channels: 4
+ model_channels: 320
+ adm_in_channels: 512
+# adm_in_channels: 2816
+ num_classes: sequential
+ attention_resolutions: [4, 2]
+ num_res_blocks: 2
+ channel_mult: [1, 2, 4]
+ num_head_channels: 64
+ use_spatial_transformer: true
+ use_linear_in_transformer: true
+ transformer_depth: [1, 2, 10]
+ context_dim: 2048
+
+ first_stage_config:
+ target: ldm.models.autoencoder.AutoencoderKL
+ params:
+ embed_dim: 4
+ ddconfig:
+ double_z: true
+ z_channels: 4
+ resolution: 512
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult: [1, 2, 4, 4]
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+
+ cond_stage_config:
+ target: refnet.modules.embedder.HFCLIPVisionModel
+ # target: refnet.modules.embedder.FrozenOpenCLIPImageEmbedder
+ params:
+ arch: ViT-bigG-14
+
+ control_encoder_config:
+# target: refnet.modules.encoder.MultiEncoder
+ target: refnet.modules.encoder.MultiScaleAttentionEncoder
+ params:
+ in_ch: 3
+ model_channels: 320
+ ch_mults: [ 1, 2, 4 ]
+
+ img_embedder_config:
+ target: refnet.modules.embedder.WDv14SwinTransformerV2
+
+ scalar_embedder_config:
+ target: refnet.modules.embedder.TimestepEmbedding
+ params:
+ embed_dim: 256
+
+ proj_config:
+ target: refnet.modules.proj.ClusterConcat
+# target: refnet.modules.proj.RecoveryClusterConcat
+ params:
+ input_dim: 1280
+ c_dim: 1024
+ output_dim: 2048
+ token_length: 196
+ dim_head: 128
+# proj_config:
+# target: refnet.modules.proj.LogitClusterConcat
+# params:
+# input_dim: 1280
+# c_dim: 1024
+# output_dim: 2048
+# token_length: 196
+# dim_head: 128
+# mlp_in_dim: 9083
+# mlp_ckpt_path: pretrained_models/proj.safetensors
\ No newline at end of file
diff --git a/configs/inference/xlv2.yaml b/configs/inference/xlv2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..035b09df80ad2082bef3a352a399ea98b706ab19
--- /dev/null
+++ b/configs/inference/xlv2.yaml
@@ -0,0 +1,108 @@
+model:
+ base_learning_rate: 1.0e-6
+ target: refnet.models.v2-colorizerXL.InferenceWrapperXL
+ params:
+ linear_start: 0.00085
+ linear_end: 0.0120
+ timesteps: 1000
+ image_size: 128
+ channels: 4
+ scale_factor: 0.13025
+ controller: true
+
+ unet_config:
+ target: refnet.modules.unet.DualCondUNetXL
+ params:
+ use_checkpoint: True
+ in_channels: 4
+ in_channels_fg: 4
+ out_channels: 4
+ model_channels: 320
+ adm_in_channels: 512
+ num_classes: sequential
+ attention_resolutions: [4, 2]
+ num_res_blocks: 2
+ channel_mult: [1, 2, 4]
+ num_head_channels: 64
+ use_spatial_transformer: true
+ use_linear_in_transformer: true
+ transformer_depth: [1, 2, 10]
+ context_dim: 2048
+ map_module: false
+ warp_module: false
+ style_modulation: false
+
+ bg_encoder_config:
+ target: refnet.modules.unet.ReferenceNet
+ params:
+ use_checkpoint: True
+ in_channels: 6
+ model_channels: 320
+ adm_in_channels: 1024
+ num_classes: sequential
+ attention_resolutions: [ 4, 2 ]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 4 ]
+ num_head_channels: 64
+ use_spatial_transformer: true
+ use_linear_in_transformer: true
+ disable_cross_attentions: true
+ context_dim: 2048
+ transformer_depth: [ 1, 2, 10 ]
+
+
+ first_stage_config:
+ target: ldm.models.autoencoder.AutoencoderKL
+ params:
+ embed_dim: 4
+ ddconfig:
+ double_z: true
+ z_channels: 4
+ resolution: 512
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult: [1, 2, 4, 4]
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+
+ cond_stage_config:
+ target: refnet.modules.embedder.HFCLIPVisionModel
+ params:
+ arch: ViT-bigG-14
+
+ img_embedder_config:
+ target: refnet.modules.embedder.WDv14SwinTransformerV2
+
+ control_encoder_config:
+ target: refnet.modules.encoder.MultiScaleAttentionEncoder
+ params:
+ in_ch: 3
+ model_channels: 320
+ ch_mults: [1, 2, 4]
+
+ proj_config:
+ target: refnet.modules.proj.ClusterConcat
+ # target: refnet.modules.proj.RecoveryClusterConcat
+ params:
+ input_dim: 1280
+ c_dim: 1024
+ output_dim: 2048
+ token_length: 196
+ dim_head: 128
+
+ scalar_embedder_config:
+ target: refnet.modules.embedder.TimestepEmbedding
+ params:
+ embed_dim: 256
+
+ lora_config:
+ lora_params: [
+ {
+ label: background,
+ root_module: model.diffusion_model,
+ target_keys: [ attn2.to_q, attn2.to_k, attn2.to_v ],
+ r: 4,
+ }
+ ]
\ No newline at end of file
diff --git a/configs/scheduler_cfgs/ddim.yaml b/configs/scheduler_cfgs/ddim.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..fb3d92393b590034df264ebf4e4555c3d5e4492a
--- /dev/null
+++ b/configs/scheduler_cfgs/ddim.yaml
@@ -0,0 +1,10 @@
+beta_start: 0.00085
+beta_end: 0.012
+beta_schedule: "scaled_linear"
+clip_sample: false
+steps_offset: 1
+
+### Zero-SNR params
+#rescale_betas_zero_snr: True
+#timestep_spacing: "trailing"
+timestep_spacing: "leading"
\ No newline at end of file
diff --git a/configs/scheduler_cfgs/dpm.yaml b/configs/scheduler_cfgs/dpm.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9410a2c84583c3310378ee0d9a32bda9b97c1f63
--- /dev/null
+++ b/configs/scheduler_cfgs/dpm.yaml
@@ -0,0 +1,8 @@
+beta_start: 0.00085
+beta_end: 0.012
+beta_schedule: "scaled_linear"
+steps_offset: 1
+
+### Zero-SNR params
+#rescale_betas_zero_snr: True
+timestep_spacing: "leading"
\ No newline at end of file
diff --git a/configs/scheduler_cfgs/dpm_sde.yaml b/configs/scheduler_cfgs/dpm_sde.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f383f8d6dba924c49a6a2f0717d3b93926db3071
--- /dev/null
+++ b/configs/scheduler_cfgs/dpm_sde.yaml
@@ -0,0 +1,9 @@
+beta_start: 0.00085
+beta_end: 0.012
+beta_schedule: "scaled_linear"
+steps_offset: 1
+
+### Zero-SNR params
+#rescale_betas_zero_snr: True
+timestep_spacing: "leading"
+algorithm_type: sde-dpmsolver++
\ No newline at end of file
diff --git a/configs/scheduler_cfgs/lms.yaml b/configs/scheduler_cfgs/lms.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..26950c26cca70bfb186b161aec77ab85118882b7
--- /dev/null
+++ b/configs/scheduler_cfgs/lms.yaml
@@ -0,0 +1,9 @@
+beta_start: 0.00085
+beta_end: 0.012
+beta_schedule: "scaled_linear"
+#clip_sample: false
+steps_offset: 1
+
+### Zero-SNR params
+#rescale_betas_zero_snr: True
+timestep_spacing: "leading"
\ No newline at end of file
diff --git a/configs/scheduler_cfgs/pndm.yaml b/configs/scheduler_cfgs/pndm.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c8342ca35f07d2f66ba9429f7283be86db0b82f7
--- /dev/null
+++ b/configs/scheduler_cfgs/pndm.yaml
@@ -0,0 +1,10 @@
+beta_start: 0.00085
+beta_end: 0.012
+beta_schedule: "scaled_linear"
+#clip_sample: false
+steps_offset: 1
+
+### Zero-SNR params
+#rescale_betas_zero_snr: True
+#timestep_spacing: "trailing"
+timestep_spacing: "leading"
\ No newline at end of file
diff --git a/k_diffusion/__init__.py b/k_diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..029a4a219b0259c183d177517a44fc0c0582dfe0
--- /dev/null
+++ b/k_diffusion/__init__.py
@@ -0,0 +1,8 @@
+from .sampling import *
+
+
def create_noise_sampler(x, sigmas, seed):
    """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
    # NOTE(review): BrownianTreeNoiseSampler is likely already in scope via the
    # star import at the top of this module; this local import may be redundant — confirm.
    from k_diffusion.sampling import BrownianTreeNoiseSampler
    # Only the strictly positive sigmas bound the Brownian tree's time interval;
    # the appended terminal sigma == 0 is excluded.
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed)
\ No newline at end of file
diff --git a/k_diffusion/external.py b/k_diffusion/external.py
new file mode 100644
index 0000000000000000000000000000000000000000..18b0fc6d317044baeb0a7bedc3da9d46189f538e
--- /dev/null
+++ b/k_diffusion/external.py
@@ -0,0 +1,181 @@
+import math
+
+import torch
+from torch import nn
+
+from . import sampling, utils
+
+
class VDenoiser(nn.Module):
    """A v-diffusion-pytorch model wrapper for k-diffusion.

    Adapts a v-objective model to the Karras sigma parameterization via the
    standard preconditioning with sigma_data = 1.
    """

    def __init__(self, inner_model):
        super().__init__()
        self.inner_model = inner_model
        self.sigma_data = 1.

    def get_scalings(self, sigma):
        # Preconditioning coefficients for a v-objective model.
        total_var = sigma ** 2 + self.sigma_data ** 2
        c_skip = self.sigma_data ** 2 / total_var
        c_out = -sigma * self.sigma_data / total_var ** 0.5
        c_in = 1 / total_var ** 0.5
        return c_skip, c_out, c_in

    def sigma_to_t(self, sigma):
        # sigma = tan(t * pi / 2): maps sigma in (0, inf) onto t in (0, 1).
        return sigma.atan() / math.pi * 2

    def t_to_sigma(self, t):
        return (t * math.pi / 2).tan()

    def loss(self, input, noise, sigma, **kwargs):
        """Per-sample MSE between the model output and the v target at *sigma*."""
        c_skip, c_out, c_in = [utils.append_dims(c, input.ndim) for c in self.get_scalings(sigma)]
        noised_input = input + noise * utils.append_dims(sigma, input.ndim)
        model_output = self.inner_model(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
        target = (input - c_skip * noised_input) / c_out
        return (model_output - target).pow(2).flatten(1).mean(1)

    def forward(self, input, sigma, **kwargs):
        """Denoises *input* at noise level *sigma*, returning the predicted clean sample."""
        c_skip, c_out, c_in = [utils.append_dims(c, input.ndim) for c in self.get_scalings(sigma)]
        return self.inner_model(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip
+
+
class DiscreteSchedule(nn.Module):
    """A mapping between continuous noise levels (sigmas) and a list of discrete noise
    levels (the integer timesteps of a trained DDPM-style model)."""

    def __init__(self, sigmas, quantize):
        super().__init__()
        # Buffers so .to() / .cuda() on the module also moves the schedule.
        self.register_buffer('sigmas', sigmas)
        self.register_buffer('log_sigmas', sigmas.log())
        self.quantize = quantize

    @property
    def sigma_min(self):
        return self.sigmas[0]

    @property
    def sigma_max(self):
        return self.sigmas[-1]

    def get_sigmas(self, n=None):
        """Returns a descending schedule of n sigmas (all of them if n is None)
        with a trailing zero appended."""
        if n is None:
            return sampling.append_zero(self.sigmas.flip(0))
        t_max = len(self.sigmas) - 1
        steps = torch.linspace(t_max, 0, n, device=self.sigmas.device)
        return sampling.append_zero(self.t_to_sigma(steps))

    def sigma_to_t(self, sigma, quantize=None):
        """Maps *sigma* to a timestep: nearest integer when quantizing, otherwise a
        fractional index from linear interpolation in log-sigma space."""
        quantize = self.quantize if quantize is None else quantize
        log_sigma = sigma.log()
        dists = log_sigma - self.log_sigmas[:, None]
        if quantize:
            return dists.abs().argmin(dim=0).view(sigma.shape)
        low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1
        low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx]
        w = ((low - log_sigma) / (low - high)).clamp(0, 1)
        return ((1 - w) * low_idx + w * high_idx).view(sigma.shape)

    def t_to_sigma(self, t):
        """Inverse of sigma_to_t: interpolates log-sigma between the bracketing steps."""
        t = t.float()
        low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
        return ((1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]).exp()
+
+
class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
    """A wrapper for discrete schedule DDPM models that output eps (the predicted
    noise)."""

    def __init__(self, model, alphas_cumprod, quantize):
        # sigma_t = sqrt((1 - alphabar_t) / alphabar_t): the DDPM-to-Karras mapping.
        super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
        self.inner_model = model
        self.sigma_data = 1.

    def get_scalings(self, sigma):
        # Preconditioning for an eps-objective model: denoised = x + c_out * eps,
        # with the input rescaled to unit variance by c_in.
        c_out = -sigma
        c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
        return c_out, c_in

    def get_eps(self, *args, **kwargs):
        """Calls the wrapped model; subclasses override to adapt its interface/output."""
        return self.inner_model(*args, **kwargs)

    def loss(self, input, noise, sigma, **kwargs):
        """Simple eps-prediction MSE loss at noise level *sigma*, reduced per sample."""
        c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
        noised_input = input + noise * utils.append_dims(sigma, input.ndim)
        eps = self.get_eps(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
        return (eps - noise).pow(2).flatten(1).mean(1)

    def forward(self, input, sigma, **kwargs):
        """Denoises *input* at noise level *sigma*, returning the predicted clean sample."""
        c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
        eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
        return input + eps * c_out
+
+
class OpenAIDenoiser(DiscreteEpsDDPMDenoiser):
    """A wrapper for OpenAI diffusion models."""

    def __init__(self, model, diffusion, quantize=False, has_learned_sigmas=True, device='cpu'):
        schedule = torch.tensor(diffusion.alphas_cumprod, device=device, dtype=torch.float32)
        super().__init__(model, schedule, quantize=quantize)
        self.has_learned_sigmas = has_learned_sigmas

    def get_eps(self, *args, **kwargs):
        out = self.inner_model(*args, **kwargs)
        if not self.has_learned_sigmas:
            return out
        # Models with learned variance stack [eps, variance] along the channel dim;
        # only the eps half is used here.
        eps, _ = out.chunk(2, dim=1)
        return eps
+
+
class CompVisDenoiser(DiscreteEpsDDPMDenoiser):
    """A wrapper for CompVis diffusion models (eps objective)."""

    def __init__(self, model, quantize=False, device='cpu'):
        super().__init__(model, model.alphas_cumprod, quantize=quantize)
        # Move the schedule buffers onto the requested device.
        for name in ('sigmas', 'log_sigmas'):
            setattr(self, name, getattr(self, name).to(device))

    def get_eps(self, *args, **kwargs):
        return self.inner_model.apply_model(*args, **kwargs)
+
+
class DiscreteVDDPMDenoiser(DiscreteSchedule):
    """A wrapper for discrete schedule DDPM models that output v."""

    def __init__(self, model, alphas_cumprod, quantize):
        # sigma_t = sqrt((1 - alphabar_t) / alphabar_t): the DDPM-to-Karras mapping.
        super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
        self.inner_model = model
        self.sigma_data = 1.

    def get_scalings(self, sigma):
        # Preconditioning coefficients for a v-objective model.
        c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
        c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
        c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
        return c_skip, c_out, c_in

    def get_v(self, *args, **kwargs):
        """Calls the wrapped model; subclasses override to adapt its interface."""
        return self.inner_model(*args, **kwargs)

    def loss(self, input, noise, sigma, **kwargs):
        """v-prediction MSE training loss at noise level *sigma*, reduced per sample."""
        c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
        noised_input = input + noise * utils.append_dims(sigma, input.ndim)
        model_output = self.get_v(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
        target = (input - c_skip * noised_input) / c_out
        return (model_output - target).pow(2).flatten(1).mean(1)

    def forward(self, input, sigma, **kwargs):
        """Denoises *input* at noise level *sigma*, returning the predicted clean sample."""
        c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
        return self.get_v(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip
+
+
class CompVisVDenoiser(DiscreteVDDPMDenoiser):
    """A wrapper for CompVis diffusion models that output v."""

    def __init__(self, model, quantize=False, device='cpu'):
        super().__init__(model, model.alphas_cumprod, quantize=quantize)
        # Move the schedule buffers onto the requested device.
        for name in ('sigmas', 'log_sigmas'):
            setattr(self, name, getattr(self, name).to(device))

    def get_v(self, x, t, cond, **kwargs):
        # Extra kwargs are accepted for interface compatibility but not forwarded.
        return self.inner_model.apply_model(x, t, cond)
diff --git a/k_diffusion/sampling.py b/k_diffusion/sampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..f842d3029e1ec65a29c7758de9785e6839cd86fc
--- /dev/null
+++ b/k_diffusion/sampling.py
@@ -0,0 +1,702 @@
+import math
+
+from scipy import integrate
+import torch
+from torch import nn
+from torchdiffeq import odeint
+import torchsde
+from tqdm.auto import trange, tqdm
+
+from . import utils
+
+
def append_zero(x):
    """Returns *x* with a single zero appended (the terminal, noise-free sigma)."""
    zero = x.new_zeros([1])
    return torch.cat([x, zero])
+
+
def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
    """Constructs the noise schedule of Karras et al. (2022).

    Interpolates linearly between sigma_max**(1/rho) and sigma_min**(1/rho) and
    raises the ramp to the rho-th power, yielding a descending schedule of n
    sigmas with a trailing zero appended (length n + 1).

    Args:
        n: number of sampling steps.
        sigma_min, sigma_max: endpoints of the noise range.
        rho: schedule curvature (7 per the paper's recommendation).
        device: device the schedule is created on.
    """
    # Build the ramp directly on the target device; the original built it on CPU,
    # moved it, then redundantly moved the final result a second time.
    ramp = torch.linspace(0, 1, n, device=device)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    return append_zero(sigmas)
+
+
def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'):
    """Constructs an exponential noise schedule: n sigmas evenly spaced in log
    space from sigma_max down to sigma_min, with a trailing zero appended."""
    log_sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device)
    return append_zero(log_sigmas.exp())
+
+
def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'):
    """Constructs a noise schedule that is polynomial in log sigma: the log-space
    ramp is warped by *rho* (rho=1 reduces to the exponential schedule)."""
    warped_ramp = torch.linspace(1, 0, n, device=device) ** rho
    log_min, log_max = math.log(sigma_min), math.log(sigma_max)
    sigmas = torch.exp(warped_ramp * (log_max - log_min) + log_min)
    return append_zero(sigmas)
+
+
def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
    """Constructs a continuous VP (variance preserving) noise schedule with a
    trailing zero appended."""
    t = torch.linspace(1, eps_s, n, device=device)
    # sigma(t)^2 = exp(beta_d * t^2 / 2 + beta_min * t) - 1 for the VP SDE.
    sigmas = (torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1).sqrt()
    return append_zero(sigmas)
+
+
def to_d(x, sigma, denoised):
    """Converts a denoiser output to a Karras ODE derivative (dx/dsigma).

    sigma is broadcast against x by appending trailing singleton dims.
    """
    return (x - denoised) / utils.append_dims(sigma, x.ndim)
+
+
def get_ancestral_step(sigma_from, sigma_to, eta=1.):
    """Calculates the noise level (sigma_down) to step down to and the amount
    of noise to add (sigma_up) when doing an ancestral sampling step."""
    if not eta:
        # Deterministic step: go straight to sigma_to, re-inject nothing.
        return sigma_to, 0.
    var_from = sigma_from ** 2
    var_to = sigma_to ** 2
    max_up = eta * (var_to * (var_from - var_to) / var_from) ** 0.5
    sigma_up = min(sigma_to, max_up)
    # Chosen so that stepping down and adding sigma_up noise lands at sigma_to total.
    sigma_down = (var_to - sigma_up ** 2) ** 0.5
    return sigma_down, sigma_up
+
+
def default_noise_sampler(x):
    """Returns a noise sampler that draws i.i.d. standard normal noise shaped
    like *x*, ignoring the (sigma, sigma_next) interval it is asked about."""
    def sampler(sigma, sigma_next):
        return torch.randn_like(x)
    return sampler
+
+
class BatchedBrownianTree:
    """A wrapper around torchsde.BrownianTree that enables batches of entropy.

    *seed* may be None (a random seed is drawn), a single int (one tree shared by
    the whole batch), or a sequence with exactly one seed per batch item (one
    tree per item, stacked on __call__).
    """

    def __init__(self, x, t0, t1, seed=None, **kwargs):
        t0, t1, self.sign = self.sort(t0, t1)
        w0 = kwargs.get('w0', torch.zeros_like(x))
        if seed is None:
            seed = torch.randint(0, 2 ** 63 - 1, []).item()
        self.batched = True
        try:
            num_seeds = len(seed)
        except TypeError:
            # A single scalar seed: one tree shared by the batch.
            seed = [seed]
            self.batched = False
        else:
            # Was a bare `assert`, which is silently stripped under `python -O`
            # and would then build a mis-sized tree list.
            if num_seeds != x.shape[0]:
                raise ValueError('the number of seeds must match the batch size')
            w0 = w0[0]
        self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]

    @staticmethod
    def sort(a, b):
        """Orders (a, b) ascending and returns the sign flip needed to undo it."""
        return (a, b, 1) if a < b else (b, a, -1)

    def __call__(self, t0, t1):
        t0, t1, sign = self.sort(t0, t1)
        w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
        return w if self.batched else w[0]
+
+
class BrownianTreeNoiseSampler:
    """A noise sampler backed by a torchsde.BrownianTree.

    Args:
        x (Tensor): The tensor whose shape, device and dtype to use to generate
            random samples.
        sigma_min (float): The low end of the valid interval.
        sigma_max (float): The high end of the valid interval.
        seed (int or List[int]): The random seed. If a list of seeds is
            supplied instead of a single integer, then the noise sampler will
            use one BrownianTree per batch item, each with its own seed.
        transform (callable): A function that maps sigma to the sampler's
            internal timestep.
    """

    def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
        self.transform = transform
        # The tree is parameterized over the (transformed) sigma interval.
        t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max))
        self.tree = BatchedBrownianTree(x, t0, t1, seed)

    def __call__(self, sigma, sigma_next):
        t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next))
        # A Brownian increment over [t0, t1] has variance |t1 - t0|; dividing by
        # sqrt(|t1 - t0|) normalizes it to a standard normal sample.
        return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
+
+
@torch.no_grad()
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).

    Args:
        model: denoiser; model(x, sigma) returns the predicted clean sample.
        x: initial latents, already scaled to sigmas[0].
        sigmas: descending noise schedule.
        s_churn / s_tmin / s_tmax / s_noise: stochastic "churn" controls; with
            the defaults the sampler is fully deterministic given x.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])  # broadcasts the scalar sigma over the batch
    for i in trange(len(sigmas) - 1, disable=disable):
        # Churn: temporarily raise sigma and add matching noise before the step.
        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
        eps = torch.randn_like(x) * s_noise  # drawn every step; used only when gamma > 0
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        dt = sigmas[i + 1] - sigma_hat
        # Euler method
        x = x + d * dt
    return x
+
+
@torch.no_grad()
def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with Euler method steps.

    Each step moves deterministically down to sigma_down, then re-injects
    sigma_up worth of fresh noise; eta controls the split (eta=0 is deterministic).
    """
    extra_args = {} if extra_args is None else extra_args
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])  # broadcasts the scalar sigma over the batch
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        d = to_d(x, sigmas[i], denoised)
        # Euler method
        dt = sigma_down - sigmas[i]
        x = x + d * dt
        if sigmas[i + 1] > 0:
            # Re-inject noise, except on the final step down to sigma == 0.
            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
    return x
+
+
@torch.no_grad()
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).

    Second-order: averages the ODE derivative at both ends of each interval;
    the final step (to sigma == 0) falls back to plain Euler.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])  # broadcasts the scalar sigma over the batch
    for i in trange(len(sigmas) - 1, disable=disable):
        # Churn: temporarily raise sigma and add matching noise before the step.
        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
        eps = torch.randn_like(x) * s_noise  # drawn every step; used only when gamma > 0
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        dt = sigmas[i + 1] - sigma_hat
        if sigmas[i + 1] == 0:
            # Euler method
            x = x + d * dt
        else:
            # Heun's method: correct with the derivative at the step's endpoint.
            x_2 = x + d * dt
            denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
            d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
            d_prime = (d + d_2) / 2
            x = x + d_prime * dt
    return x
+
+
@torch.no_grad()
def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])  # broadcasts the scalar sigma over the batch
    for i in trange(len(sigmas) - 1, disable=disable):
        # Churn: temporarily raise sigma and add matching noise before the step.
        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
        eps = torch.randn_like(x) * s_noise  # drawn every step; used only when gamma > 0
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Euler method
            dt = sigmas[i + 1] - sigma_hat
            x = x + d * dt
        else:
            # DPM-Solver-2: midpoint evaluation at the geometric mean of the two sigmas.
            sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
            dt_1 = sigma_mid - sigma_hat
            dt_2 = sigmas[i + 1] - sigma_hat
            x_2 = x + d * dt_1
            denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
            d_2 = to_d(x_2, sigma_mid, denoised_2)
            x = x + d_2 * dt_2
    return x
+
+
@torch.no_grad()
def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with DPM-Solver second-order steps.

    Steps down to sigma_down with a midpoint method, then re-injects sigma_up
    worth of fresh noise (see get_ancestral_step).
    """
    extra_args = {} if extra_args is None else extra_args
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])  # broadcasts the scalar sigma over the batch
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        d = to_d(x, sigmas[i], denoised)
        if sigma_down == 0:
            # Euler method
            dt = sigma_down - sigmas[i]
            x = x + d * dt
        else:
            # DPM-Solver-2: midpoint evaluation at the geometric mean sigma.
            sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp()
            dt_1 = sigma_mid - sigmas[i]
            dt_2 = sigma_down - sigmas[i]
            x_2 = x + d * dt_1
            denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
            d_2 = to_d(x_2, sigma_mid, denoised_2)
            x = x + d_2 * dt_2
        x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
    return x
+
+
def linear_multistep_coeff(order, t, i, j):
    """Integrates the j-th Lagrange basis polynomial (built on the *order* most
    recent points of *t*) over [t[i], t[i+1]]: the Adams-Bashforth-style
    coefficient used by sample_lms.

    Raises:
        ValueError: if fewer than *order* history points exist at step *i*.
    """
    if order - 1 > i:
        raise ValueError(f'Order {order} too high for step {i}')

    def basis(tau):
        value = 1.
        for k in range(order):
            if j == k:
                continue
            value *= (tau - t[i - k]) / (t[i - j] - t[i - k])
        return value

    return integrate.quad(basis, t[i], t[i + 1], epsrel=1e-4)[0]
+
+
@torch.no_grad()
def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
    """Linear multistep sampling over the sigma schedule.

    Keeps a history of up to *order* past derivatives and combines them using
    exactly-integrated coefficients from linear_multistep_coeff.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])  # broadcasts the scalar sigma over the batch
    sigmas_cpu = sigmas.detach().cpu().numpy()  # coefficients are integrated on CPU via scipy
    ds = []  # derivative history, most recent last
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        d = to_d(x, sigmas[i], denoised)
        ds.append(d)
        if len(ds) > order:
            ds.pop(0)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        # The effective order ramps up while the history fills.
        cur_order = min(i + 1, order)
        coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
        x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
    return x
+
+
@torch.no_grad()
def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4):
    """Estimates log p(x) by integrating the probability-flow ODE from sigma_min
    to sigma_max, with the divergence term estimated stochastically (single
    Rademacher probe, Skilling-Hutchinson style).

    Returns:
        (per-sample log-likelihood, info dict with the model-eval count).
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    v = torch.randint_like(x, 2) * 2 - 1  # random +/-1 probe vector
    fevals = 0
    def ode_fn(sigma, x):
        nonlocal fevals
        with torch.enable_grad():
            x = x[0].detach().requires_grad_()
            denoised = model(x, sigma * s_in, **extra_args)
            d = to_d(x, sigma, denoised)
            fevals += 1
            grad = torch.autograd.grad((d * v).sum(), x)[0]
            d_ll = (v * grad).flatten(1).sum(1)  # divergence estimate v^T (dd/dx) v
        return d.detach(), d_ll
    # State is (sample, accumulated log-likelihood change).
    x_min = x, x.new_zeros([x.shape[0]])
    t = x.new_tensor([sigma_min, sigma_max])
    sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5')
    latent, delta_ll = sol[0][-1], sol[1][-1]
    # Prior log-prob of the latent at sigma_max plus the accumulated change.
    ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1)
    return ll_prior + delta_ll, {'fevals': fevals}
+
+
class PIDStepSizeController:
    """A PID controller for ODE adaptive step size control."""

    def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8):
        self.h = h
        # PID gains expressed in multiplicative step-factor form, scaled by order.
        self.b1 = (pcoeff + icoeff + dcoeff) / order
        self.b2 = -(pcoeff + 2 * dcoeff) / order
        self.b3 = dcoeff / order
        self.accept_safety = accept_safety
        self.eps = eps
        self.errs = []  # inverse-error history: [current, previous, before that]

    def limiter(self, x):
        # Smoothly clamps the factor so one bad error estimate cannot blow the
        # step size up or collapse it to zero.
        return 1 + math.atan(x - 1)

    def propose_step(self, error):
        inv_error = 1 / (float(error) + self.eps)
        if not self.errs:
            # First call: seed the whole history with the current value.
            self.errs = [inv_error, inv_error, inv_error]
        self.errs[0] = inv_error
        raw_factor = self.errs[0] ** self.b1 * self.errs[1] ** self.b2 * self.errs[2] ** self.b3
        factor = self.limiter(raw_factor)
        accept = factor >= self.accept_safety
        if accept:
            # Only shift the history forward on accepted steps.
            self.errs[2] = self.errs[1]
            self.errs[1] = self.errs[0]
        self.h *= factor
        return accept
+
+
class DPMSolver(nn.Module):
    """DPM-Solver. See https://arxiv.org/abs/2206.00927.

    Works in the variable t = -log(sigma). Model evaluations within one step are
    cached in an *eps_cache* dict so shared evaluations are not repeated.
    """

    def __init__(self, model, extra_args=None, eps_callback=None, info_callback=None):
        super().__init__()
        self.model = model  # denoiser: model(x, sigma, ...) -> predicted clean sample
        self.extra_args = {} if extra_args is None else extra_args
        self.eps_callback = eps_callback  # called after every model evaluation
        self.info_callback = info_callback  # called once per solver step

    def t(self, sigma):
        """Maps sigma to solver time t = -log(sigma)."""
        return -sigma.log()

    def sigma(self, t):
        """Inverse of t(): sigma = exp(-t)."""
        return t.neg().exp()

    def eps(self, eps_cache, key, x, t, *args, **kwargs):
        """Returns the model's predicted noise at (x, t), cached under *key*."""
        if key in eps_cache:
            return eps_cache[key], eps_cache
        sigma = self.sigma(t) * x.new_ones([x.shape[0]])
        # eps = (x - denoised) / sigma
        eps = (x - self.model(x, sigma, *args, **self.extra_args, **kwargs)) / self.sigma(t)
        if self.eps_callback is not None:
            self.eps_callback()
        return eps, {key: eps, **eps_cache}

    def dpm_solver_1_step(self, x, t, t_next, eps_cache=None):
        """First-order update from t to t_next."""
        eps_cache = {} if eps_cache is None else eps_cache
        h = t_next - t
        eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
        x_1 = x - self.sigma(t_next) * h.expm1() * eps
        return x_1, eps_cache

    def dpm_solver_2_step(self, x, t, t_next, r1=1 / 2, eps_cache=None):
        """Second-order update; r1 places the intermediate evaluation point."""
        eps_cache = {} if eps_cache is None else eps_cache
        h = t_next - t
        eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
        s1 = t + r1 * h
        u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
        eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
        x_2 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / (2 * r1) * h.expm1() * (eps_r1 - eps)
        return x_2, eps_cache

    def dpm_solver_3_step(self, x, t, t_next, r1=1 / 3, r2=2 / 3, eps_cache=None):
        """Third-order update with intermediate evaluations at r1 and r2."""
        eps_cache = {} if eps_cache is None else eps_cache
        h = t_next - t
        eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
        s1 = t + r1 * h
        s2 = t + r2 * h
        u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
        eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
        u2 = x - self.sigma(s2) * (r2 * h).expm1() * eps - self.sigma(s2) * (r2 / r1) * ((r2 * h).expm1() / (r2 * h) - 1) * (eps_r1 - eps)
        eps_r2, eps_cache = self.eps(eps_cache, 'eps_r2', u2, s2)
        x_3 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / r2 * (h.expm1() / h - 1) * (eps_r2 - eps)
        return x_3, eps_cache

    def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None):
        """Fixed-step solver: splits [t_start, t_end] evenly, using third-order
        steps where possible and spending the leftover eval budget on lower orders."""
        noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
        if not t_end > t_start and eta:
            raise ValueError('eta must be 0 for reverse sampling')

        m = math.floor(nfe / 3) + 1
        ts = torch.linspace(t_start, t_end, m + 1, device=x.device)

        # Choose per-interval orders so the total model evals equal nfe.
        if nfe % 3 == 0:
            orders = [3] * (m - 2) + [2, 1]
        else:
            orders = [3] * (m - 1) + [nfe % 3]

        for i in range(len(orders)):
            eps_cache = {}
            t, t_next = ts[i], ts[i + 1]
            if eta:
                # Ancestral-style split: overshoot downward, then add noise back.
                sd, su = get_ancestral_step(self.sigma(t), self.sigma(t_next), eta)
                t_next_ = torch.minimum(t_end, self.t(sd))
                su = (self.sigma(t_next) ** 2 - self.sigma(t_next_) ** 2) ** 0.5
            else:
                t_next_, su = t_next, 0.

            eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
            denoised = x - self.sigma(t) * eps
            if self.info_callback is not None:
                self.info_callback({'x': x, 'i': i, 't': ts[i], 't_up': t, 'denoised': denoised})

            if orders[i] == 1:
                x, eps_cache = self.dpm_solver_1_step(x, t, t_next_, eps_cache=eps_cache)
            elif orders[i] == 2:
                x, eps_cache = self.dpm_solver_2_step(x, t, t_next_, eps_cache=eps_cache)
            else:
                x, eps_cache = self.dpm_solver_3_step(x, t, t_next_, eps_cache=eps_cache)

            x = x + su * s_noise * noise_sampler(self.sigma(t), self.sigma(t_next))

        return x

    def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None):
        """Adaptive-step solver: an embedded lower-order estimate provides a local
        error, which a PID controller turns into step-size updates."""
        noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
        if order not in {2, 3}:
            raise ValueError('order should be 2 or 3')
        forward = t_end > t_start
        if not forward and eta:
            raise ValueError('eta must be 0 for reverse sampling')
        h_init = abs(h_init) * (1 if forward else -1)
        atol = torch.tensor(atol)
        rtol = torch.tensor(rtol)
        s = t_start
        x_prev = x
        accept = True
        pid = PIDStepSizeController(h_init, pcoeff, icoeff, dcoeff, 1.5 if eta else order, accept_safety)
        info = {'steps': 0, 'nfe': 0, 'n_accept': 0, 'n_reject': 0}

        while s < t_end - 1e-5 if forward else s > t_end + 1e-5:
            eps_cache = {}
            # Clamp the proposed step so it never overshoots t_end.
            t = torch.minimum(t_end, s + pid.h) if forward else torch.maximum(t_end, s + pid.h)
            if eta:
                # Ancestral-style split: overshoot downward, then add noise back.
                sd, su = get_ancestral_step(self.sigma(s), self.sigma(t), eta)
                t_ = torch.minimum(t_end, self.t(sd))
                su = (self.sigma(t) ** 2 - self.sigma(t_) ** 2) ** 0.5
            else:
                t_, su = t, 0.

            eps, eps_cache = self.eps(eps_cache, 'eps', x, s)
            denoised = x - self.sigma(s) * eps

            # Embedded pair: the low/high order estimates share cached model evals.
            if order == 2:
                x_low, eps_cache = self.dpm_solver_1_step(x, s, t_, eps_cache=eps_cache)
                x_high, eps_cache = self.dpm_solver_2_step(x, s, t_, eps_cache=eps_cache)
            else:
                x_low, eps_cache = self.dpm_solver_2_step(x, s, t_, r1=1 / 3, eps_cache=eps_cache)
                x_high, eps_cache = self.dpm_solver_3_step(x, s, t_, eps_cache=eps_cache)
            delta = torch.maximum(atol, rtol * torch.maximum(x_low.abs(), x_prev.abs()))
            error = torch.linalg.norm((x_low - x_high) / delta) / x.numel() ** 0.5
            accept = pid.propose_step(error)
            if accept:
                x_prev = x_low
                x = x_high + su * s_noise * noise_sampler(self.sigma(s), self.sigma(t))
                s = t
                info['n_accept'] += 1
            else:
                info['n_reject'] += 1
            info['nfe'] += order
            info['steps'] += 1

            if self.info_callback is not None:
                self.info_callback({'x': x, 'i': info['steps'] - 1, 't': s, 't_up': s, 'denoised': denoised, 'error': error, 'h': pid.h, **info})

        return x, info
+
+
@torch.no_grad()
def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1., noise_sampler=None):
    """DPM-Solver-Fast (fixed step size). See https://arxiv.org/abs/2206.00927."""
    # sigma(t) = exp(-t) in this wrapper, so sigma == 0 has no finite t.
    if sigma_min <= 0 or sigma_max <= 0:
        raise ValueError('sigma_min and sigma_max must not be 0')
    with tqdm(total=n, disable=disable) as pbar:
        # Each model (eps) evaluation ticks the progress bar.
        dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
        if callback is not None:
            # Translate internal t values back to sigmas for the caller's callback.
            dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})
        return dpm_solver.dpm_solver_fast(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), n, eta, s_noise, noise_sampler)
+
+
@torch.no_grad()
def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None, return_info=False):
    """DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927."""
    # sigma(t) = exp(-t), so sigma must be strictly positive to map to a finite t.
    if sigma_min <= 0 or sigma_max <= 0:
        raise ValueError('sigma_min and sigma_max must not be 0')
    # The number of steps is not known in advance, so the bar has no fixed total.
    with tqdm(disable=disable) as pbar:
        dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
        if callback is not None:
            # Translate internal t values back to sigmas for the caller's callback.
            dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})
        x, info = dpm_solver.dpm_solver_adaptive(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise, noise_sampler)
    if return_info:
        return x, info
    return x
+
+
@torch.no_grad()
def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with DPM-Solver++(2S) second-order steps."""
    extra_args = {} if extra_args is None else extra_args
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])
    # Half-log-SNR parameterization: t = -log(sigma), sigma = exp(-t).
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        # Split the step: deterministic part down to sigma_down, then re-noise by sigma_up.
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigma_down == 0:
            # Euler method (the log-space step below is undefined at sigma == 0)
            d = to_d(x, sigmas[i], denoised)
            dt = sigma_down - sigmas[i]
            x = x + d * dt
        else:
            # DPM-Solver++(2S): one extra model evaluation at the midpoint s = t + h/2
            t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
            r = 1 / 2
            h = t_next - t
            s = t + r * h
            x_2 = (sigma_fn(s) / sigma_fn(t)) * x - (-h * r).expm1() * denoised
            denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_2
        # Noise addition (the ancestral part of the step)
        if sigmas[i + 1] > 0:
            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
    return x
+
+
@torch.no_grad()
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
    """DPM-Solver++ (stochastic)."""
    # Brownian-tree noise makes the SDE sampler reproducible for a given seed.
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    # Half-log-SNR parameterization: t = -log(sigma), sigma = exp(-t).
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Euler method for the final step to sigma = 0
            d = to_d(x, sigmas[i], denoised)
            dt = sigmas[i + 1] - sigmas[i]
            x = x + d * dt
        else:
            # DPM-Solver++ two-stage stochastic step; r is the midpoint fraction.
            t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
            h = t_next - t
            s = t + h * r
            fac = 1 / (2 * r)

            # Step 1: ancestral sub-step to the intermediate time s
            sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
            s_ = t_fn(sd)
            x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
            x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
            denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)

            # Step 2: full step using the extrapolated denoised estimate
            sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
            t_next_ = t_fn(sd)
            denoised_d = (1 - fac) * denoised + fac * denoised_2
            x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
            x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
    return x
+
+
@torch.no_grad()
def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """DPM-Solver++(2M)."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    # Half-log-SNR parameterization: t = -log(sigma), sigma = exp(-t).
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()
    old_denoised = None  # multistep history: previous step's denoised estimate

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
        h = t_next - t
        if old_denoised is None or sigmas[i + 1] == 0:
            # First step (no history yet) and the final step are first order.
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
        else:
            # Second-order multistep: extrapolate from the previous denoised estimate.
            h_last = t - t_fn(sigmas[i - 1])
            r = h_last / h
            denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
        old_denoised = denoised
    return x
+
+
@torch.no_grad()
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
    """DPM-Solver++(2M) SDE."""

    if solver_type not in {'heun', 'midpoint'}:
        raise ValueError('solver_type must be \'heun\' or \'midpoint\'')

    # Brownian-tree noise makes the SDE sampler reproducible for a given seed.
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    old_denoised = None  # multistep history: previous denoised estimate...
    h_last = None  # ...and the step size it was produced with

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Denoising step
            x = denoised
        else:
            # DPM-Solver++(2M) SDE; t = -log(sigma) parameterization.
            t, s = -sigmas[i].log(), -sigmas[i + 1].log()
            h = s - t
            # eta scales the amount of stochastic churn injected per step.
            eta_h = eta * h

            x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised

            if old_denoised is not None:
                # Second-order correction from the previous denoised estimate.
                r = h_last / h
                if solver_type == 'heun':
                    x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
                elif solver_type == 'midpoint':
                    x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)

            if eta:
                x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise

        old_denoised = denoised
        h_last = h
    return x
+
+
@torch.no_grad()
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """DPM-Solver++(3M) SDE."""

    # Brownian-tree noise makes the SDE sampler reproducible for a given seed.
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    # Multistep history: the last two denoised estimates and their step sizes.
    denoised_1, denoised_2 = None, None
    h_1, h_2 = None, None

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Denoising step
            x = denoised
        else:
            # t = -log(sigma); h_eta folds the stochastic churn into the step size.
            t, s = -sigmas[i].log(), -sigmas[i + 1].log()
            h = s - t
            h_eta = h * (eta + 1)

            x = torch.exp(-h_eta) * x + (-h_eta).expm1().neg() * denoised

            if h_2 is not None:
                # Third-order correction from two steps of history.
                r0 = h_1 / h
                r1 = h_2 / h
                d1_0 = (denoised - denoised_1) / r0
                d1_1 = (denoised_1 - denoised_2) / r1
                d1 = d1_0 + (d1_0 - d1_1) * r0 / (r0 + r1)
                d2 = (d1_0 - d1_1) / (r0 + r1)
                phi_2 = h_eta.neg().expm1() / h_eta + 1
                phi_3 = phi_2 / h_eta - 0.5
                x = x + phi_2 * d1 - phi_3 * d2
            elif h_1 is not None:
                # Second-order correction from one step of history.
                r = h_1 / h
                d = (denoised - denoised_1) / r
                phi_2 = h_eta.neg().expm1() / h_eta + 1
                x = x + phi_2 * d

            if eta:
                x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise

        denoised_1, denoised_2 = denoised, denoised_1
        h_1, h_2 = h, h_1
    return x
diff --git a/k_diffusion/utils.py b/k_diffusion/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..da857b4a8e41af93af9d010542e9246998d95d2b
--- /dev/null
+++ b/k_diffusion/utils.py
@@ -0,0 +1,457 @@
+from contextlib import contextmanager
+import hashlib
+import math
+from pathlib import Path
+import shutil
+import threading
+import urllib
+import warnings
+
+from PIL import Image
+import safetensors
+import torch
+from torch import nn, optim
+from torch.utils import data
+from torchvision.transforms import functional as TF
+
+
def from_pil_image(x):
    """Converts from a PIL image to a tensor scaled to [-1, 1]."""
    tensor = TF.to_tensor(x)
    if tensor.ndim == 2:
        # NOTE(review): to_tensor returns a CHW tensor, so this branch looks
        # unreachable; kept for parity with the original behavior.
        tensor = tensor[..., None]
    return tensor * 2 - 1
+
+
def to_pil_image(x):
    """Converts from a tensor in [-1, 1] to a PIL image."""
    if x.ndim == 4:
        # Only singleton batches can be converted to a single image.
        assert x.shape[0] == 1
        x = x.squeeze(0)
    if x.shape[0] == 1:
        # Drop a singleton channel dimension (grayscale).
        x = x.squeeze(0)
    scaled = (x.clamp(-1, 1) + 1) / 2
    return TF.to_pil_image(scaled)
+
+
def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'):
    """Apply passed in transforms for HuggingFace Datasets."""
    converted = (image.convert(mode) for image in examples[image_key])
    return {image_key: [transform(image) for image in converted]}
+
+
def append_dims(x, target_dims):
    """Appends trailing singleton dimensions to *x* until it has target_dims dimensions."""
    missing = target_dims - x.ndim
    if missing < 0:
        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
    index = (Ellipsis,) + (None,) * missing
    return x[index]
+
+
def n_params(module):
    """Returns the total number of parameters in a module.

    NOTE(review): counts every parameter regardless of requires_grad, despite
    the original docstring saying "trainable".
    """
    total = 0
    for param in module.parameters():
        total += param.numel()
    return total
+
+
def download_file(path, url, digest=None):
    """Downloads a file if it does not exist, optionally checking its SHA-256 hash.

    Args:
        path: Destination file path; parent directories are created as needed.
        url: Source URL, fetched only when *path* does not already exist.
        digest: Optional expected SHA-256 hex digest of the file contents.

    Returns:
        The destination path as a ``pathlib.Path``.

    Raises:
        OSError: If *digest* is given and does not match the file's hash.
    """
    # Import the submodule explicitly: `import urllib` alone does not
    # guarantee that `urllib.request` is available as an attribute.
    from urllib import request

    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    if not path.exists():
        with request.urlopen(url) as response, open(path, 'wb') as f:
            shutil.copyfileobj(response, f)
    if digest is not None:
        # Hash in chunks inside a context manager: the original read the whole
        # file into memory through a file handle it never closed.
        hasher = hashlib.sha256()
        with open(path, 'rb') as f:
            for chunk in iter(lambda: f.read(1 << 20), b''):
                hasher.update(chunk)
        if digest != hasher.hexdigest():
            raise OSError(f'hash of {path} (url: {url}) failed to validate')
    return path
+
+
@contextmanager
def train_mode(model, mode=True):
    """A context manager that places a model into training mode and restores
    the previous mode on exit."""
    # Snapshot every submodule's training flag before switching.
    saved = [submodule.training for submodule in model.modules()]
    try:
        yield model.train(mode)
    finally:
        for submodule, was_training in zip(model.modules(), saved):
            submodule.training = was_training
+
+
def eval_mode(model):
    """A context manager that places a model into evaluation mode and restores
    the previous mode on exit."""
    return train_mode(model, mode=False)
+
+
@torch.no_grad()
def ema_update(model, averaged_model, decay):
    """Incorporates updated model parameters into an exponential moving averaged
    version of a model. It should be called after each optimizer step."""
    params = dict(model.named_parameters())
    avg_params = dict(averaged_model.named_parameters())
    # Both models must share an identical parameter structure.
    assert params.keys() == avg_params.keys()
    for name, avg in avg_params.items():
        avg.lerp_(params[name], 1 - decay)

    buffers = dict(model.named_buffers())
    avg_buffers = dict(averaged_model.named_buffers())
    assert buffers.keys() == avg_buffers.keys()
    # Buffers (e.g. batch-norm statistics) are copied verbatim, not averaged.
    for name, avg in avg_buffers.items():
        avg.copy_(buffers[name])
+
+
class EMAWarmup:
    """Implements an EMA warmup using an inverse decay schedule.

    With inv_gamma=1 and power=1 this is a simple average. inv_gamma=1,
    power=2/3 suits models trained for a million or more steps (decay reaches
    0.999 at 31.6K steps, 0.9999 at 1M steps); inv_gamma=1, power=3/4 suits
    shorter runs (0.999 at 10K steps, 0.9999 at 215.4K steps).

    Args:
        inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
        power (float): Exponential factor of EMA warmup. Default: 1.
        min_value (float): The minimum EMA decay rate. Default: 0.
        max_value (float): The maximum EMA decay rate. Default: 1.
        start_at (int): The epoch to start averaging at. Default: 0.
        last_epoch (int): The index of last epoch. Default: 0.
    """

    def __init__(self, inv_gamma=1., power=1., min_value=0., max_value=1., start_at=0,
                 last_epoch=0):
        self.inv_gamma = inv_gamma
        self.power = power
        self.min_value = min_value
        self.max_value = max_value
        self.start_at = start_at
        self.last_epoch = last_epoch

    def state_dict(self):
        """Returns the state of the class as a :class:`dict`."""
        return {key: value for key, value in self.__dict__.items()}

    def load_state_dict(self, state_dict):
        """Loads the class's state.

        Args:
            state_dict (dict): scaler state, as returned by :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_value(self):
        """Gets the current EMA decay rate."""
        epoch = max(0, self.last_epoch - self.start_at)
        if epoch < 0:
            # Unreachable after the max() above; kept for parity with the original.
            return 0.
        raw = 1 - (1 + epoch / self.inv_gamma) ** -self.power
        return min(self.max_value, max(self.min_value, raw))

    def step(self):
        """Updates the step count."""
        self.last_epoch += 1
+
+
class InverseLR(optim.lr_scheduler._LRScheduler):
    """Inverse decay learning rate schedule with an optional exponential warmup.

    The learning rate decays to (1 / 2)**power of its original value after
    inv_gamma steps/epochs. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1.
        power (float): Exponential factor of learning rate decay. Default: 1.
        warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable).
            Default: 0.
        min_lr (float): The minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for each update.
            Default: ``False``.
    """

    def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., min_lr=0.,
                 last_epoch=-1, verbose=False):
        self.inv_gamma = inv_gamma
        self.power = power
        if not 0. <= warmup < 1:
            raise ValueError('Invalid value for warmup')
        self.warmup = warmup
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.")

        return self._get_closed_form_lr()

    def _get_closed_form_lr(self):
        # Warmup multiplier approaches 1 as warmup**(epoch+1) decays to 0.
        warmup = 1 - self.warmup ** (self.last_epoch + 1)
        decay = (1 + self.last_epoch / self.inv_gamma) ** -self.power
        lrs = []
        for base_lr in self.base_lrs:
            lrs.append(warmup * max(self.min_lr, base_lr * decay))
        return lrs
+
+
class ExponentialLR(optim.lr_scheduler._LRScheduler):
    """Continuous exponential learning rate schedule with an optional
    exponential warmup. When last_epoch=-1, sets initial lr as lr. The learning
    rate is multiplied by ``decay`` (default 0.5) every ``num_steps`` steps.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        num_steps (float): The number of steps to decay the learning rate by decay in.
        decay (float): The factor by which to decay the learning rate every
            num_steps steps. Default: 0.5.
        warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable).
            Default: 0.
        min_lr (float): The minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for each update.
            Default: ``False``.
    """

    def __init__(self, optimizer, num_steps, decay=0.5, warmup=0., min_lr=0.,
                 last_epoch=-1, verbose=False):
        self.num_steps = num_steps
        self.decay = decay
        if not 0. <= warmup < 1:
            raise ValueError('Invalid value for warmup')
        self.warmup = warmup
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.")

        return self._get_closed_form_lr()

    def _get_closed_form_lr(self):
        warmup = 1 - self.warmup ** (self.last_epoch + 1)
        # Per-step multiplier so that decay is applied smoothly over num_steps.
        per_step = self.decay ** (1 / self.num_steps)
        lr_mult = per_step ** self.last_epoch
        lrs = []
        for base_lr in self.base_lrs:
            lrs.append(warmup * max(self.min_lr, base_lr * lr_mult))
        return lrs
+
+
class ConstantLRWithWarmup(optim.lr_scheduler._LRScheduler):
    """Constant learning rate schedule with an optional exponential warmup.
    When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable).
            Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for each update.
            Default: ``False``.
    """

    def __init__(self, optimizer, warmup=0., last_epoch=-1, verbose=False):
        if not 0. <= warmup < 1:
            raise ValueError('Invalid value for warmup')
        self.warmup = warmup
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.")

        return self._get_closed_form_lr()

    def _get_closed_form_lr(self):
        # Warmup multiplier approaches 1 as warmup**(epoch+1) decays to 0.
        warmup = 1 - self.warmup ** (self.last_epoch + 1)
        return [base_lr * warmup for base_lr in self.base_lrs]
+
+
def stratified_uniform(shape, group=0, groups=1, dtype=None, device=None):
    """Draws stratified samples from a uniform distribution.

    The last axis is split into equal strata (across all *groups*); each
    sample is jittered uniformly within its own stratum.
    """
    if groups <= 0:
        raise ValueError(f"groups must be positive, got {groups}")
    if group < 0 or group >= groups:
        raise ValueError(f"group must be in [0, {groups})")
    strata = shape[-1] * groups
    # This group's stratum offsets, interleaved with the other groups'.
    offsets = torch.arange(group, strata, groups, dtype=dtype, device=device)
    jitter = torch.rand(shape, dtype=dtype, device=device)
    return (offsets + jitter) / strata
+
+
# Per-thread stratified-sampling configuration, installed by enable_stratified();
# absence of its attributes means stratification is disabled on this thread.
stratified_settings = threading.local()
+
+
@contextmanager
def enable_stratified(group=0, groups=1, disable=False):
    """A context manager that enables stratified sampling.

    Installs this thread's stratum assignment so stratified_with_settings()
    picks it up for every draw inside the ``with`` body.
    """
    try:
        stratified_settings.disable = disable
        stratified_settings.group = group
        stratified_settings.groups = groups
        yield
    finally:
        # Remove the attributes entirely so subsequent draws fall back to
        # plain (unstratified) uniform sampling.
        del stratified_settings.disable
        del stratified_settings.group
        del stratified_settings.groups
+
+
@contextmanager
def enable_stratified_accelerate(accelerator, disable=False):
    """A context manager that enables stratified sampling, distributing the strata across
    all processes and gradient accumulation steps using settings from Hugging Face Accelerate."""
    # NOTE(review): the try/finally with a bare `pass` is a no-op wrapper;
    # cleanup is already handled by the nested enable_stratified() context.
    try:
        rank = accelerator.process_index
        world_size = accelerator.num_processes
        acc_steps = accelerator.gradient_state.num_steps
        acc_step = accelerator.step % acc_steps
        # Each (process, accumulation step) pair owns one distinct stratum group.
        group = rank * acc_steps + acc_step
        groups = world_size * acc_steps
        with enable_stratified(group, groups, disable=disable):
            yield
    finally:
        pass
+
+
def stratified_with_settings(shape, dtype=None, device=None):
    """Draws stratified samples from a uniform distribution, using settings from a context
    manager."""
    # Fall back to plain uniform sampling when no enable_stratified() context is
    # active on this thread, or when it was entered with disable=True.
    if not hasattr(stratified_settings, 'disable') or stratified_settings.disable:
        return torch.rand(shape, dtype=dtype, device=device)
    return stratified_uniform(
        shape, stratified_settings.group, stratified_settings.groups, dtype=dtype, device=device
    )
+
+
def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32):
    """Draws samples from a lognormal distribution."""
    # Clamp u into (0, 1) so icdf never sees exactly 0 or 1 (which map to +/-inf).
    u = stratified_with_settings(shape, device=device, dtype=dtype) * (1 - 2e-7) + 1e-7
    return torch.distributions.Normal(loc, scale).icdf(u).exp()
+
+
def rand_log_logistic(shape, loc=0., scale=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
    """Draws samples from an optionally truncated log-logistic distribution."""
    # Work in float64: the truncation CDF endpoints can be very close together,
    # and the logit below amplifies rounding error near 0 and 1.
    min_value = torch.as_tensor(min_value, device=device, dtype=torch.float64)
    max_value = torch.as_tensor(max_value, device=device, dtype=torch.float64)
    # Inverse-CDF sampling restricted to [cdf(min_value), cdf(max_value)].
    min_cdf = min_value.log().sub(loc).div(scale).sigmoid()
    max_cdf = max_value.log().sub(loc).div(scale).sigmoid()
    u = stratified_with_settings(shape, device=device, dtype=torch.float64) * (max_cdf - min_cdf) + min_cdf
    return u.logit().mul(scale).add(loc).exp().to(dtype)
+
+
def rand_log_uniform(shape, min_value, max_value, device='cpu', dtype=torch.float32):
    """Draws samples from a log-uniform distribution."""
    # Sample uniformly in log space, then exponentiate back.
    min_value = math.log(min_value)
    max_value = math.log(max_value)
    return (stratified_with_settings(shape, device=device, dtype=dtype) * (max_value - min_value) + min_value).exp()
+
+
def rand_v_diffusion(shape, sigma_data=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
    """Draws samples from a truncated v-diffusion training timestep distribution."""
    # sigma = sigma_data * tan(pi * t / 2), so the CDF over t is atan-based;
    # sample the CDF range [cdf(min_value), cdf(max_value)] and invert.
    min_cdf = math.atan(min_value / sigma_data) * 2 / math.pi
    max_cdf = math.atan(max_value / sigma_data) * 2 / math.pi
    u = stratified_with_settings(shape, device=device, dtype=dtype) * (max_cdf - min_cdf) + min_cdf
    return torch.tan(u * math.pi / 2) * sigma_data
+
+
def rand_cosine_interpolated(shape, image_d, noise_d_low, noise_d_high, sigma_data=1., min_value=1e-3, max_value=1e3, device='cpu', dtype=torch.float32):
    """Draws samples from an interpolated cosine timestep distribution (from simple diffusion)."""

    def logsnr_schedule_cosine(t, logsnr_min, logsnr_max):
        # Standard cosine log-SNR schedule over t in [0, 1].
        t_min = math.atan(math.exp(-0.5 * logsnr_max))
        t_max = math.atan(math.exp(-0.5 * logsnr_min))
        return -2 * torch.log(torch.tan(t_min + t * (t_max - t_min)))

    def logsnr_schedule_cosine_shifted(t, image_d, noise_d, logsnr_min, logsnr_max):
        # Shift the schedule to account for resolution (image_d vs noise_d).
        shift = 2 * math.log(noise_d / image_d)
        return logsnr_schedule_cosine(t, logsnr_min - shift, logsnr_max - shift) + shift

    def logsnr_schedule_cosine_interpolated(t, image_d, noise_d_low, noise_d_high, logsnr_min, logsnr_max):
        # Blend the low- and high-resolution shifted schedules linearly in t.
        logsnr_low = logsnr_schedule_cosine_shifted(t, image_d, noise_d_low, logsnr_min, logsnr_max)
        logsnr_high = logsnr_schedule_cosine_shifted(t, image_d, noise_d_high, logsnr_min, logsnr_max)
        return torch.lerp(logsnr_low, logsnr_high, t)

    # Convert the sigma bounds into log-SNR bounds (logsnr = -2 log(sigma / sigma_data)).
    logsnr_min = -2 * math.log(min_value / sigma_data)
    logsnr_max = -2 * math.log(max_value / sigma_data)
    u = stratified_with_settings(shape, device=device, dtype=dtype)
    logsnr = logsnr_schedule_cosine_interpolated(u, image_d, noise_d_low, noise_d_high, logsnr_min, logsnr_max)
    return torch.exp(-logsnr / 2) * sigma_data
+
+
def rand_split_log_normal(shape, loc, scale_1, scale_2, device='cpu', dtype=torch.float32):
    """Draws samples from a split lognormal distribution.

    The left and right halves of a normal around *loc* get different scales;
    sides are chosen with probability proportional to their scale.
    """
    magnitude = torch.randn(shape, device=device, dtype=dtype).abs()
    side = torch.rand(shape, device=device, dtype=dtype)
    left = loc - magnitude * scale_1
    right = loc + magnitude * scale_2
    ratio = scale_1 / (scale_1 + scale_2)
    return torch.where(side < ratio, left, right).exp()
+
+
class FolderOfImages(data.Dataset):
    """Recursively finds all images in a directory. It does not support
    classes/targets."""

    IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'}

    def __init__(self, root, transform=None):
        super().__init__()
        self.root = Path(root)
        self.transform = transform if transform is not None else nn.Identity()
        # Deterministic ordering: recursive scan filtered by extension, sorted.
        candidates = self.root.rglob('*')
        self.paths = sorted(p for p in candidates if p.suffix.lower() in self.IMG_EXTENSIONS)

    def __repr__(self):
        return f'FolderOfImages(root="{self.root}", len: {len(self)})'

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, key):
        # Returned as a 1-tuple to mirror (input, target)-style datasets.
        with open(self.paths[key], 'rb') as f:
            image = Image.open(f).convert('RGB')
        return (self.transform(image),)
+
+
class CSVLogger:
    """Appends rows of comma-separated values to a file.

    The header row is written only when the file is newly created; an existing
    file is opened in append mode. NOTE: values are joined naively with commas
    (no quoting/escaping), matching the original behavior — callers must not
    pass values containing commas or newlines.
    """

    def __init__(self, filename, columns):
        self.filename = Path(filename)
        self.columns = columns
        if self.filename.exists():
            self.file = open(self.filename, 'a')
        else:
            self.file = open(self.filename, 'w')
            self.write(*self.columns)

    def write(self, *args):
        # Flush per row so log lines survive a crash.
        print(*args, sep=',', file=self.file, flush=True)

    def close(self):
        """Close the underlying file (the original leaked the handle)."""
        if not self.file.closed:
            self.file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
+
+
@contextmanager
def tf32_mode(cudnn=None, matmul=None):
    """A context manager that sets whether TF32 is allowed on cuDNN or matmul.

    A ``None`` argument leaves the corresponding backend flag untouched.
    """
    saved = (torch.backends.cudnn.allow_tf32, torch.backends.cuda.matmul.allow_tf32)
    try:
        if cudnn is not None:
            torch.backends.cudnn.allow_tf32 = cudnn
        if matmul is not None:
            torch.backends.cuda.matmul.allow_tf32 = matmul
        yield
    finally:
        # Restore only the flags this context actually changed.
        if cudnn is not None:
            torch.backends.cudnn.allow_tf32 = saved[0]
        if matmul is not None:
            torch.backends.cuda.matmul.allow_tf32 = saved[1]
+
+
def get_safetensors_metadata(path):
    """Retrieves the metadata from a safetensors file."""
    # "pt" selects the PyTorch framework for the safetensors reader.
    handle = safetensors.safe_open(path, "pt")
    return handle.metadata()
+
+
def ema_update_dict(values, updates, decay):
    """EMA-updates entries of *values* in place from *updates* and returns it.

    New keys are inserted verbatim; existing entries are updated in place
    (``*=`` / ``+=``) so tensor/array values keep their identity.
    """
    for key, new_value in updates.items():
        if key not in values:
            values[key] = new_value
        else:
            values[key] *= decay
            values[key] += (1 - decay) * new_value
    return values
diff --git a/ldm/modules/diffusionmodules/__init__.py b/ldm/modules/diffusionmodules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b1dd07cee6ca2252f61b72e6425d60880366cb2
--- /dev/null
+++ b/ldm/modules/diffusionmodules/model.py
@@ -0,0 +1,488 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from einops import rearrange
+from typing import Optional, Any
+
+from refnet.util import checkpoint_wrapper, default
+
try:
    import xformers
    import xformers.ops

    XFORMERS_IS_AVAILBLE = True
    attn_processor = xformers.ops.memory_efficient_attention
except Exception:
    # Fall back to PyTorch's native scaled-dot-product attention when xformers
    # is missing or fails to load. `except Exception` (not a bare `except:`)
    # so KeyboardInterrupt/SystemExit are not swallowed during import.
    XFORMERS_IS_AVAILBLE = False
    attn_processor = F.scaled_dot_product_attention
+
+
def get_timestep_embedding(timesteps, embedding_dim):
    """Sinusoidal timestep embeddings (DDPM / tensor2tensor convention).

    Returns a (len(timesteps), embedding_dim) float tensor whose first half
    holds sines and second half cosines over log-spaced frequencies; an odd
    embedding_dim is zero-padded by one trailing column.
    """
    assert timesteps.ndim == 1

    half = embedding_dim // 2
    # Log-spaced frequencies running from 1 down to ~1/10000.
    scale = math.log(10000) / (half - 1)
    freqs = torch.exp(torch.arange(half, dtype=torch.float32) * -scale)
    freqs = freqs.to(device=timesteps.device)
    angles = timesteps.float()[:, None] * freqs[None, :]
    embedding = torch.cat((torch.sin(angles), torch.cos(angles)), dim=1)
    if embedding_dim % 2 == 1:  # zero pad the missing column
        embedding = F.pad(embedding, (0, 1, 0, 0))
    return embedding
+
+
def nonlinearity(x):
    """Swish/SiLU activation: x * sigmoid(x)."""
    return torch.sigmoid(x).mul(x)
+
+
def Normalize(in_channels, num_groups=32):
    """GroupNorm with affine parameters and a small eps (standard LDM VAE norm)."""
    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+
+
class Upsample(nn.Module):
    """Doubles spatial resolution with nearest-neighbor interpolation,
    optionally refined by a 3x3 convolution."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            self.conv = nn.Conv2d(in_channels, in_channels,
                                  kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        upsampled = F.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled) if self.with_conv else upsampled
+
+
class Downsample(nn.Module):
    """Halves spatial resolution, either with a stride-2 3x3 conv using
    asymmetric right/bottom padding, or with 2x2 average pooling."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            # torch convs only support symmetric padding, so the asymmetric
            # (right/bottom) padding is applied manually in forward().
            self.conv = nn.Conv2d(in_channels, in_channels,
                                  kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if not self.with_conv:
            return F.avg_pool2d(x, kernel_size=2, stride=2)
        padded = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
+
+
class ResnetBlock(nn.Module):
    """Pre-activation residual block (GroupNorm -> swish -> conv, twice) with an
    optional timestep-embedding injection and a learned shortcut when the
    channel count changes.
    """

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if temb_channels > 0:
            # Projects the timestep embedding to a per-channel bias.
            self.temb_proj = torch.nn.Linear(temb_channels,
                                             out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if self.in_channels != self.out_channels:
            # Match channel counts on the skip path: a full 3x3 conv or a cheap
            # 1x1 ("network-in-network") projection.
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels,
                                                     out_channels,
                                                     kernel_size=3,
                                                     stride=1,
                                                     padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels,
                                                    out_channels,
                                                    kernel_size=1,
                                                    stride=1,
                                                    padding=0)

    # presumably wraps forward with activation checkpointing (refnet.util) — confirm
    @checkpoint_wrapper
    def forward(self, x, temb=None):
        """Residual forward pass; `temb` is an optional timestep embedding."""
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            # Add the projected timestep embedding as a per-channel bias.
            h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x+h
+
+
class AttnBlock(nn.Module):
    """Single-head self-attention over all spatial positions, implemented with
    explicit batched matmuls (the reference, non-memory-efficient version).
    1x1 convs act as per-pixel linear projections for q/k/v and the output.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)

    def forward(self, x):
        """Attend over the h*w positions; residual connection around attention."""
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b,c,h,w = q.shape
        q = q.reshape(b,c,h*w)
        q = q.permute(0,2,1)   # b,hw,c
        k = k.reshape(b,c,h*w) # b,c,hw
        w_ = torch.bmm(q,k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c)**(-0.5))  # scale by 1/sqrt(d) before softmax
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b,c,h*w)
        w_ = w_.permute(0,2,1)   # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v,w_)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b,c,h,w)

        h_ = self.proj_out(h_)

        return x+h_
+
class MemoryEfficientAttnBlock(nn.Module):
    """
    Uses xformers efficient implementation,
    see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
    Note: this is a single-head self-attention operation
    """

    def __init__(self, in_channels, head_dim=None):
        super().__init__()
        self.in_channels = in_channels
        # head_dim defaults to in_channels, i.e. a single head.
        self.head_dim = default(head_dim, in_channels)
        self.heads = in_channels // self.head_dim
        # NOTE(review): with head_dim=None (the only way make_attn constructs
        # this block) heads == 1 and the reshapes in forward() are exact. For
        # head_dim < in_channels the reshapes below use C (full channels)
        # rather than head_dim — verify multi-head splitting before using it.
        # if self.head_dim > 256:
        #     self.attn_processor = F.scaled_dot_product_attention
        # else:
        self.attn_processor = attn_processor

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)
        self.attention_op: Optional[Any] = None

    def forward(self, x):
        """Self-attention with a residual connection, delegating the attention
        kernel to xformers (or SDPA as fallback)."""
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        B, C, H, W = q.shape
        # Flatten spatial dims: (B, C, H, W) -> (B, H*W, C).
        q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))

        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(B, -1, self.heads, C)
            .permute(0, 2, 1, 3)
            .reshape(B * self.heads, -1, C)
            .contiguous(),
            (q, k, v),
        )
        out = self.attn_processor(q, k, v)

        # Undo the head folding and restore the (B, C, H, W) layout.
        out = (
            out.unsqueeze(0)
            .reshape(B, 1, out.shape[1], C)
            .permute(0, 2, 1, 3)
            .reshape(B, out.shape[1], C)
        )
        out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
        out = self.proj_out(out)
        return x+out
+
+
def make_attn(in_channels, **kwargs):
    # Always builds the memory-efficient block; extra kwargs (e.g. attn_type
    # passed by Encoder/Decoder) are accepted for compatibility but ignored.
    return MemoryEfficientAttnBlock(in_channels)
+
+
+
class Encoder(nn.Module):
    """Convolutional VAE encoder: a stack of downsampling ResNet stages with
    optional attention at selected resolutions, a middle block, and a final
    projection to z_channels (doubled when double_z, for mean+logvar).

    NOTE(review): out_ch, use_linear_attn and attn_type are accepted for config
    compatibility; make_attn ignores the attention type, and out_ch is unused.
    """

    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
                 checkpoint=True, **ignore_kwargs):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0  # no timestep conditioning in the autoencoder
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        in_ch_mult = (1,)+tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                # One attention block per ResNet block at the chosen resolutions.
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        2*z_channels if double_z else z_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)
        self.checkpoint = checkpoint

    # presumably wraps forward with activation checkpointing (refnet.util) — confirm
    @checkpoint_wrapper
    def forward(self, x):
        """Encode an image batch to latent moments (2*z_channels when double_z)."""
        # timestep embedding
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions-1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
+
+
class Decoder(nn.Module):
    """Convolutional VAE decoder: maps a latent z back to image space through a
    middle block and a stack of upsampling ResNet stages (mirroring Encoder).

    `give_pre_end` returns the feature map before the final norm/conv;
    `tanh_out` squashes the output to [-1, 1].
    NOTE(review): use_linear_attn/attn_type are accepted for config
    compatibility but make_attn ignores the attention type.
    """

    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
                 attn_type="vanilla", checkpoint=True, **ignorekwargs):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0  # no timestep conditioning in the autoencoder
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,)+tuple(ch_mult)
        block_in = ch*ch_mult[self.num_resolutions-1]
        curr_res = resolution // 2**(self.num_resolutions-1)
        self.z_shape = (1,z_channels,curr_res,curr_res)

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels,
                                       block_in,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            # One extra ResNet block per level compared to the encoder.
            for i_block in range(self.num_res_blocks+1):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up) # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)
        self.checkpoint = checkpoint

    # presumably wraps forward with activation checkpointing (refnet.util) — confirm
    @checkpoint_wrapper
    def forward(self, z):
        """Decode latent z to an image (or pre-end features when give_pre_end)."""
        #assert z.shape[1:] == self.z_shape[1:]
        # kept for debugging/inspection
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks+1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h
\ No newline at end of file
diff --git a/ldm/modules/distributions/__init__.py b/ldm/modules/distributions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/modules/distributions/distributions.py b/ldm/modules/distributions/distributions.py
new file mode 100644
index 0000000000000000000000000000000000000000..507b82a4fd8d177727c964a4b0217bd6766ab106
--- /dev/null
+++ b/ldm/modules/distributions/distributions.py
@@ -0,0 +1,92 @@
+import torch
+import numpy as np
+
+
class AbstractDistribution:
    """Minimal distribution interface: subclasses implement sample() and mode()."""

    def sample(self):
        raise NotImplementedError()

    def mode(self):
        raise NotImplementedError()
+
+
class DiracDistribution(AbstractDistribution):
    """Point-mass distribution: both sampling and the mode return the fixed value."""

    def __init__(self, value):
        self.value = value

    def sample(self):
        return self.value

    def mode(self):
        return self.value
+
+
class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian over image-shaped latents.

    `parameters` is a (B, 2C, H, W) tensor holding the mean in the first C
    channels and the log-variance in the last C. When `deterministic`, the
    distribution degenerates to a point mass at the mean.
    """

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        # Clamp for numerical stability before exponentiation.
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # Zero std/var so sample() returns exactly the mean.
            self.var = self.std = torch.zeros_like(self.mean)

    def sample(self):
        """Draw a reparameterized sample: mean + std * eps."""
        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
        return x

    def kl(self, other=None):
        """KL divergence to `other`, or to the standard normal when `other` is
        None; summed over channel/spatial dims, one value per batch item."""
        if self.deterministic:
            return torch.Tensor([0.])
        if other is None:
            return 0.5 * torch.sum(torch.pow(self.mean, 2)
                                   + self.var - 1.0 - self.logvar,
                                   dim=[1, 2, 3])
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var - 1.0 - self.logvar + other.logvar,
            dim=[1, 2, 3])

    def nll(self, sample, dims=(1, 2, 3)):
        """Negative log-likelihood of `sample`, summed over `dims`.

        `dims` defaults to an immutable tuple (was a mutable list default).
        """
        if self.deterministic:
            return torch.Tensor([0.])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        """Return the distribution's mode (the mean)."""
        return self.mean
+
+
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    # Find any Tensor argument to serve as dtype/device reference.
    ref = next(
        (a for a in (mean1, logvar1, mean2, logvar2) if isinstance(a, torch.Tensor)),
        None,
    )
    assert ref is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors: broadcasting converts scalar means, but
    # torch.exp() below needs actual tensors for the log-variances.
    logvar1, logvar2 = (
        lv if isinstance(lv, torch.Tensor) else torch.tensor(lv).to(ref)
        for lv in (logvar1, logvar2)
    )

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + torch.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )
diff --git a/preprocessor/__init__.py b/preprocessor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9f74f96424781101eb41633498af68911c3d877
--- /dev/null
+++ b/preprocessor/__init__.py
@@ -0,0 +1,124 @@
+import os
+
+import torch.hub
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.transforms.functional as tf
+import functools
+
# Local cache directory for all preprocessor weights.
model_path = "preprocessor/weights"
# Route HuggingFace Hub downloads and torch.hub checkpoints into that folder;
# HF_HOME must be set before `transformers` is imported below.
os.environ["HF_HOME"] = model_path
torch.hub.set_dir(model_path)
+
+from torch.hub import download_url_to_file
+from transformers import AutoModelForImageSegmentation
+from .anime2sketch import UnetGenerator
+from .manga_line_extractor import res_skip
+from .sketchKeras import SketchKeras
+from .sk_model import LineartDetector
+from .anime_segment import ISNetDIS
+from refnet.util import load_weights
+
+
class NoneMaskExtractor(nn.Module):
    """Placeholder mask extractor that always yields an all-zero mask."""

    def __init__(self):
        super().__init__()
        # Keeps the module non-empty; has no effect on computation.
        self.identity = nn.Identity()

    def proceed(self, x: torch.Tensor, th=None, tw=None, dilate=False, *args, **kwargs):
        """Return a (B, 1, H, W) zero mask on x's device; all options are ignored."""
        batch = x.shape[0]
        height, width = x.shape[2], x.shape[3]
        return torch.zeros([batch, 1, height, width], device=x.device)

    def forward(self, x):
        return self.proceed(x)
+
+
# Download URLs for the single-file line-extractor / segmentation checkpoints.
remote_model_dict = {
    "lineart": "https://huggingface.co/lllyasviel/Annotators/resolve/main/netG.pth",
    "lineart_denoise": "https://huggingface.co/lllyasviel/Annotators/resolve/main/erika.pth",
    "lineart_keras": "https://huggingface.co/tellurion/line_extractor/resolve/main/model.pth",
    "lineart_sk": "https://huggingface.co/lllyasviel/Annotators/resolve/main/sk_model.pth",
    "ISNet": "https://huggingface.co/tellurion/line_extractor/resolve/main/isnetis.safetensors",
    "ISNet-sketch": "https://huggingface.co/tellurion/line_extractor/resolve/main/sketch-segment.safetensors"
}

# Background-removal models loaded via transformers: (HF repo id, native input size).
BiRefNet_dict = {
    "rmbg-v2": ("briaai/RMBG-2.0", 1024),
    "BiRefNet": ("ZhengPeng7/BiRefNet", 1024),
    "BiRefNet_HR": ("ZhengPeng7/BiRefNet_HR", 2048)
}
+
def rmbg_proceed(self, x: torch.Tensor, th=None, tw=None, dilate=False, *args, **kwargs):
    """Segmentation pass bound onto a BiRefNet/RMBG model as its `.proceed` method.

    Expects `x` in [-1, 1]; returns a single-channel mask in [0, 1] at the
    input's original spatial size, optionally center-padded to (th, tw) and
    optionally dilated + binarized.
    """
    b, c, h, w = x.shape
    # Map [-1, 1] -> [0, 1], resize to the model's native resolution, then
    # apply ImageNet mean/std normalization.
    x = (x + 1.0) / 2.
    x = tf.resize(x, [self.image_size, self.image_size])
    x = tf.normalize(x, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    # The model's last output head is taken as the final mask logits.
    x = self(x)[-1].sigmoid()
    x = tf.resize(x, [h, w])

    if th and tw:
        # Center-pad the mask up to the requested target size.
        x = tf.pad(x, padding=[(th-h)//2, (tw-w)//2])
    if dilate:
        # Grow the mask via max-pooling, then binarize at 0.5.
        x = F.max_pool2d(x, kernel_size=21, stride=1, padding=10)
        # x = F.max_pool2d(x, kernel_size=11, stride=1, padding=5)
        # x = mask_expansion(x, 60, 40)
        x = torch.where(x > 0.5, torch.ones_like(x), torch.zeros_like(x))
    x = x.clamp(0, 1)
    return x
+
+
+
def create_model(model_name="lineart"):
    """Instantiate a preprocessor model by name and load its weights.

    Three cases: the no-op mask extractor ("none"), BiRefNet/RMBG segmentation
    models fetched through `transformers`, and single-file checkpoints
    downloaded on demand into `model_path`. Returns the model in eval mode.
    """
    if model_name == "none":
        return NoneMaskExtractor().eval()

    if model_name in BiRefNet_dict.keys():
        model = AutoModelForImageSegmentation.from_pretrained(
            BiRefNet_dict[model_name][0],
            trust_remote_code = True,
            cache_dir = model_path,
            device_map = None,
            low_cpu_mem_usage = False,
        )
        model.eval()
        # Attach the native input size and bind rmbg_proceed as a method so the
        # model exposes the same `.proceed` interface as the other extractors.
        model.image_size = BiRefNet_dict[model_name][1]
        model.proceed = rmbg_proceed.__get__(model, model.__class__)
        return model

    assert model_name in remote_model_dict.keys()
    remote_path = remote_model_dict[model_name]
    basename = os.path.basename(remote_path)
    ckpt_path = os.path.join(model_path, basename)

    if not os.path.exists(model_path):
        os.makedirs(model_path)

    if not os.path.exists(ckpt_path):
        # Download to a temp name first so an interrupted download cannot leave
        # a half-written checkpoint at the final path.
        cache_path = "preprocessor/weights/weights.tmp"
        download_url_to_file(remote_path, dst=cache_path)
        os.rename(cache_path, ckpt_path)

    if model_name == "lineart":
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
        model = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False)
    elif model_name == "lineart_denoise":
        model = res_skip()
    elif model_name == "lineart_keras":
        model = SketchKeras()
    elif model_name == "lineart_sk":
        model = LineartDetector()
    elif model_name == "ISNet" or model_name == "ISNet-sketch":
        model = ISNetDIS()
    else:
        # Unreachable in practice: every key in remote_model_dict is handled above.
        return None

    # Strip DataParallel's "module." prefix so weights load into a bare model.
    ckpt = load_weights(ckpt_path)
    for key in list(ckpt.keys()):
        if 'module.' in key:
            ckpt[key.replace('module.', '')] = ckpt[key]
            del ckpt[key]
    model.load_state_dict(ckpt)
    return model.eval()
\ No newline at end of file
diff --git a/preprocessor/anime2sketch.py b/preprocessor/anime2sketch.py
new file mode 100644
index 0000000000000000000000000000000000000000..56ad7fe1eb1bdf00e8a9aab89df86bf2c0316d3b
--- /dev/null
+++ b/preprocessor/anime2sketch.py
@@ -0,0 +1,119 @@
+import torch
+import torch.nn as nn
+import functools
+import torchvision.transforms as transforms
+
+"""
+ Anime2Sketch: A sketch extractor for illustration, anime art, manga
+ Author: Xiaoyu Zhang
+ Github link: https://github.com/Mukosame/Anime2Sketch
+"""
+
def to_tensor(x, inverse=False):
    """PIL image -> normalized ([-1, 1]) NCHW tensor, optionally sign-flipped.

    NOTE(review): hard-codes .cuda(), so this raises on CPU-only machines —
    confirm a GPU is always available where this is called.
    """
    x = transforms.ToTensor()(x).unsqueeze(0)
    x = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(x).cuda()
    return x if not inverse else -x
+
+
class UnetGenerator(nn.Module):
    """Create a Unet-based generator"""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet generator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            output_nc (int) -- the number of channels in output images
            num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
                                image of size 128x128 will become of size 1x1 # at the bottleneck
            ngf (int)       -- the number of filters in the last conv layer
            norm_layer      -- normalization layer
        We construct the U-Net from the innermost layer to the outermost layer.
        It is a recursive process.
        """
        super(UnetGenerator, self).__init__()
        # construct unet structure
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)  # add the innermost layer
        for _ in range(num_downs - 5):          # add intermediate layers with ngf * 8 filters
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
        # gradually reduce the number of filters from ngf * 8 to ngf
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)  # add the outermost layer

    def forward(self, input):
        """Standard forward"""
        return self.model(input)

    def proceed(self, img):
        """Extract a sketch tensor from a PIL image.

        The raw output is negated; presumably this flips line polarity for the
        downstream colorization pipeline — confirm against callers.
        """
        sketch = self(to_tensor(img))
        return -sketch
+
+
class UnetSkipConnectionBlock(nn.Module):
    """Defines the Unet submodule with skip connection.
        X -------------------identity----------------------
        |-- downsampling -- |submodule| -- upsampling --|
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet submodule with skip connections.
        Parameters:
            outer_nc (int) -- the number of filters in the outer conv layer
            inner_nc (int) -- the number of filters in the inner conv layer
            input_nc (int) -- the number of channels in input images/features
            submodule (UnetSkipConnectionBlock) -- previously defined submodules
            outermost (bool)    -- if this module is the outermost module
            innermost (bool)    -- if this module is the innermost module
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers.
        """
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        # InstanceNorm has no learnable shift, so the conv needs its own bias.
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            # Outermost level: no norm on the way down, Tanh on the way up.
            # inner_nc * 2 because the submodule concatenates its skip.
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            # Innermost level: no submodule, so no channel doubling.
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            # Final output is clamped to the Tanh range for safety.
            return self.model(x).clamp(-1, 1)
        else:   # add skip connections
            return torch.cat([x, self.model(x)], 1)
\ No newline at end of file
diff --git a/preprocessor/anime_segment.py b/preprocessor/anime_segment.py
new file mode 100644
index 0000000000000000000000000000000000000000..89aafadffce5e9e06786d0c32f6e2ddbed25f2fa
--- /dev/null
+++ b/preprocessor/anime_segment.py
@@ -0,0 +1,487 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from refnet.util import default
+
+"""
+ Source code: https://github.com/SkyTNT/anime-segmentation?tab=readme-ov-file
+ Author: SkyTNT
+"""
+
class REBNCONV(nn.Module):
    """Conv3x3 (optionally dilated) -> BatchNorm -> ReLU unit used throughout
    the RSU stages."""

    def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
        super(REBNCONV, self).__init__()

        # Padding equals the dilation rate so the spatial size is preserved
        # (for stride 1).
        self.conv_s1 = nn.Conv2d(
            in_ch, out_ch, 3, padding=dirate, dilation=dirate, stride=stride
        )
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu_s1(self.bn_s1(self.conv_s1(x)))
+
+
+## upsample tensor 'src' to have the same spatial size with tensor 'tar'
+def _upsample_like(src, tar):
+ src = F.interpolate(src, size=tar.shape[2:], mode="bilinear", align_corners=False)
+
+ return src
+
+
+### RSU-7 ###
class RSU7(nn.Module):
    """7-level Residual U-block (U^2-Net): a small encoder/decoder whose output
    is added to the input projection as a residual."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
        super(RSU7, self).__init__()

        # NOTE(review): img_size and the stored in_ch/mid_ch/out_ch attributes
        # are never read after construction in this class.
        self.in_ch = in_ch
        self.mid_ch = mid_ch
        self.out_ch = out_ch

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)  ## 1 -> 1/2

        # encoder path: 6 conv stages with 5 max-pool downsamplings
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # bottleneck uses dilation instead of further downsampling
        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # decoder path: each stage consumes [upsampled deeper feature, skip]
        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Encode with five max-pools, decode with skip concatenations, and add
        the result to the input projection."""
        # shape unpacked but unused — kept from upstream source
        b, c, h, w = x.shape

        hx = x
        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)

        hx6 = self.rebnconv6(hx)

        hx7 = self.rebnconv7(hx6)

        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
        hx6dup = _upsample_like(hx6d, hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin
+
+
+### RSU-6 ###
class RSU6(nn.Module):
    """6-level Residual U-block (U^2-Net): like RSU7 with one fewer
    downsampling stage."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU6, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # encoder path: 5 conv stages with 4 max-pool downsamplings
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # dilated bottleneck
        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # decoder path with skip concatenations
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Encoder/decoder pass with a residual add of the input projection."""
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)

        hx6 = self.rebnconv6(hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin
+
+
+### RSU-5 ###
class RSU5(nn.Module):
    """5-level Residual U-block (U^2-Net): three downsampling stages plus a
    dilated bottleneck."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU5, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # encoder path
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # dilated bottleneck
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # decoder path with skip concatenations
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Encoder/decoder pass with a residual add of the input projection."""
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)

        hx5 = self.rebnconv5(hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin
+
+
+### RSU-4 ###
class RSU4(nn.Module):
    """4-level Residual U-block (U^2-Net): two downsampling stages plus a
    dilated bottleneck."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # encoder path
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # dilated bottleneck
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # decoder path with skip concatenations
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Encoder/decoder pass with a residual add of the input projection."""
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin
+
+
### RSU-4F ###
class RSU4F(nn.Module):
    """Dilation-only RSU block (RSU-4F, from U^2-Net).

    Keeps full spatial resolution throughout: instead of pooling, the four
    encoder convs use growing dilation rates (1, 2, 4, 8), so no upsampling
    is needed on the decoder side.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F, self).__init__()

        # Input projection; also serves as the residual branch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder with increasing dilation (receptive field grows, size doesn't).
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)

        # Bottleneck at dilation 8.
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)

        # Decoder mirrors the encoder dilations over concatenated skips.
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the dilated encoder/decoder and return the residual sum."""
        res = self.rebnconvin(x)

        # Encoder path (no pooling; all features share x's spatial size).
        enc1 = self.rebnconv1(res)
        enc2 = self.rebnconv2(enc1)
        enc3 = self.rebnconv3(enc2)
        enc4 = self.rebnconv4(enc3)

        # Decoder path; no upsampling needed since sizes already match.
        dec3 = self.rebnconv3d(torch.cat((enc4, enc3), 1))
        dec2 = self.rebnconv2d(torch.cat((dec3, enc2), 1))
        dec1 = self.rebnconv1d(torch.cat((dec2, enc1), 1))

        return dec1 + res
+
+
class myrebnconv(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU helper used by the ISNet stem layers.

    All convolution hyperparameters are forwarded verbatim; defaults give a
    'same'-size 3x3 convolution.
    """

    def __init__(
        self,
        in_ch=3,
        out_ch=1,
        kernel_size=3,
        stride=1,
        padding=1,
        dilation=1,
        groups=1,
    ):
        super(myrebnconv, self).__init__()

        self.conv = nn.Conv2d(
            in_ch,
            out_ch,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
        )
        self.bn = nn.BatchNorm2d(out_ch)
        self.rl = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply convolution, batch normalization, then in-place ReLU."""
        out = self.conv(x)
        out = self.bn(out)
        return self.rl(out)
+
+
class ISNetDIS(nn.Module):
    """ISNet segmentation network (DIS: dichotomous image segmentation),
    used here to predict a foreground mask.

    Encoder: a stride-2 stem conv followed by six RSU stages; the decoder
    mirrors the encoder with concatenated skip connections. Only the finest
    side output (``side1``) is used in ``forward``; the other side heads and
    commented-out code are kept so pretrained checkpoints still load.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super(ISNetDIS, self).__init__()

        # Stem: stride-2 conv halves the input resolution before stage1.
        self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
        self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        # Encoder stages with progressively wider channels.
        self.stage1 = RSU7(64, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        # decoder
        # Each decoder stage consumes [upsampled deeper feature, encoder skip].
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        # Side-output heads. Only side1 is used at inference; side2..side6 are
        # retained for checkpoint compatibility.
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

    def forward(self, x):
        """Return sigmoid(side1) upsampled to the input's spatial size."""
        hx = x

        hxin = self.conv_in(hx)
        # NOTE(review): this pooled value is overwritten below without being
        # used — stage1 consumes `hxin` directly. This mirrors the upstream
        # ISNet reference code (where the pooling call is disabled); the
        # dead call is kept for fidelity with that source.
        hx = self.pool_in(hxin)

        # stage 1
        hx1 = self.stage1(hxin)
        hx = self.pool12(hx1)

        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        # stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        # -------------------- decoder --------------------
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # side output
        d1 = self.side1(hx1d)
        d1 = _upsample_like(d1, x)

        # d2 = self.side2(hx2d)
        # d2 = _upsample_like(d2, x)
        #
        # d3 = self.side3(hx3d)
        # d3 = _upsample_like(d3, x)
        #
        # d4 = self.side4(hx4d)
        # d4 = _upsample_like(d4, x)
        #
        # d5 = self.side5(hx5d)
        # d5 = _upsample_like(d5, x)
        #
        # d6 = self.side6(hx6)
        # d6 = _upsample_like(d6, x)

        # d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
        #
        # return [torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
        # return [d1, d2, d3, d4, d5, d6], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
        return torch.sigmoid(d1)

    def proceed(self, x: torch.Tensor, th=None, tw=None, s=1024, dilate=False, crop=True, *args, **kwargs):
        """Predict a foreground mask for a batch of images.

        Args:
            x: image tensor of shape (b, c, h, w). Values are assumed to be in
                [-1, 1] given the ``1 - (canvas + 1.) / 2.`` remapping below —
                TODO confirm against callers.
            th, tw: output height/width. Default to the input size when
                ``crop`` is True; NOTE(review): when ``crop`` is False they are
                not defaulted and must be supplied by the caller, otherwise the
                final interpolate receives ``None``.
            s: working resolution (side of the square inference canvas).
            dilate: if True, grow the mask with a 21x21 max-pool (10px radius).
            crop: if True, letterbox x into an s x s canvas preserving aspect
                ratio; if False, stretch x to s x s.

        Returns:
            Mask tensor clamped to [0, 1], resized to (th, tw).
        """
        b, c, h, w = x.shape

        if crop:
            # `default` is a project helper (presumably returns its first
            # argument unless it is None) — behavior assumed, verify upstream.
            th, tw = default(th, h), default(tw, w)
            scale = s / max(h, w)
            h, w = int(h * scale), int(w * scale)

            # Letterbox canvas filled with -1 (the input's minimum value).
            canvas = -torch.ones((b, c, s, s), dtype=x.dtype, device=x.device)
            ph, pw = (s - h) // 2, (s - w) // 2
            x = F.interpolate(x, scale_factor=scale, mode="bicubic")

            canvas[:, :, ph: ph+h, pw: pw+w] = x

            # Remap [-1, 1] -> [1, 0] (inverts intensity) before inference.
            canvas = 1 - (canvas + 1.) / 2.
            # Run the network, then crop the padded borders back off.
            mask = self(canvas)[:, :, ph: ph+h, pw: pw+w]

        else:
            x = F.interpolate(x, size=(s, s), mode="bicubic")
            mask = self(x)

        mask = F.interpolate(mask, (th, tw), mode="bicubic").clamp(0, 1)

        if dilate:
            # 21x21 max-pool with padding 10 keeps the size and grows the mask.
            mask = F.max_pool2d(mask, kernel_size=21, stride=1, padding=10)
            # mask = mask_expansion(mask, 32, 20)
        return mask
\ No newline at end of file
diff --git a/preprocessor/manga_line_extractor.py b/preprocessor/manga_line_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2d80f3f2217cec5cafe210de67519da4370fdef
--- /dev/null
+++ b/preprocessor/manga_line_extractor.py
@@ -0,0 +1,187 @@
+import torch.nn as nn
+import torchvision.transforms as transforms
+
+
+class _bn_relu_conv(nn.Module):
+ def __init__(self, in_filters, nb_filters, fw, fh, subsample=1):
+ super(_bn_relu_conv, self).__init__()
+ self.model = nn.Sequential(
+ nn.BatchNorm2d(in_filters, eps=1e-3),
+ nn.LeakyReLU(0.2),
+ nn.Conv2d(in_filters, nb_filters, (fw, fh), stride=subsample, padding=(fw//2, fh//2), padding_mode='zeros')
+ )
+
+ def forward(self, x):
+ return self.model(x)
+
+
+class _u_bn_relu_conv(nn.Module):
+ def __init__(self, in_filters, nb_filters, fw, fh, subsample=1):
+ super(_u_bn_relu_conv, self).__init__()
+ self.model = nn.Sequential(
+ nn.BatchNorm2d(in_filters, eps=1e-3),
+ nn.LeakyReLU(0.2),
+ nn.Conv2d(in_filters, nb_filters, (fw, fh), stride=subsample, padding=(fw//2, fh//2)),
+ nn.Upsample(scale_factor=2, mode='nearest')
+ )
+
+ def forward(self, x):
+ return self.model(x)
+
+
+
+class _shortcut(nn.Module):
+ def __init__(self, in_filters, nb_filters, subsample=1):
+ super(_shortcut, self).__init__()
+ self.process = False
+ self.model = None
+ if in_filters != nb_filters or subsample != 1:
+ self.process = True
+ self.model = nn.Sequential(
+ nn.Conv2d(in_filters, nb_filters, (1, 1), stride=subsample)
+ )
+
+ def forward(self, x, y):
+ #print(x.size(), y.size(), self.process)
+ if self.process:
+ y0 = self.model(x)
+ #print("merge+", torch.max(y0+y), torch.min(y0+y),torch.mean(y0+y), torch.std(y0+y), y0.shape)
+ return y0 + y
+ else:
+ #print("merge", torch.max(x+y), torch.min(x+y),torch.mean(x+y), torch.std(x+y), y.shape)
+ return x + y
+
+class _u_shortcut(nn.Module):
+ def __init__(self, in_filters, nb_filters, subsample):
+ super(_u_shortcut, self).__init__()
+ self.process = False
+ self.model = None
+ if in_filters != nb_filters:
+ self.process = True
+ self.model = nn.Sequential(
+ nn.Conv2d(in_filters, nb_filters, (1, 1), stride=subsample, padding_mode='zeros'),
+ nn.Upsample(scale_factor=2, mode='nearest')
+ )
+
+ def forward(self, x, y):
+ if self.process:
+ return self.model(x) + y
+ else:
+ return x + y
+
+
class basic_block(nn.Module):
    """Residual unit: two pre-activation 3x3 convs plus a (possibly
    projected) shortcut add. `init_subsample` strides the first conv."""

    def __init__(self, in_filters, nb_filters, init_subsample=1):
        super(basic_block, self).__init__()
        self.conv1 = _bn_relu_conv(in_filters, nb_filters, 3, 3, subsample=init_subsample)
        self.residual = _bn_relu_conv(nb_filters, nb_filters, 3, 3)
        self.shortcut = _shortcut(in_filters, nb_filters, subsample=init_subsample)

    def forward(self, x):
        """Return shortcut(x) merged with the two-conv branch."""
        branch = self.residual(self.conv1(x))
        return self.shortcut(x, branch)
+
class _u_basic_block(nn.Module):
    """Upsampling residual unit: upsampling conv, plain conv, then an
    upsampling shortcut merge."""

    def __init__(self, in_filters, nb_filters, init_subsample=1):
        super(_u_basic_block, self).__init__()
        self.conv1 = _u_bn_relu_conv(in_filters, nb_filters, 3, 3, subsample=init_subsample)
        self.residual = _bn_relu_conv(nb_filters, nb_filters, 3, 3)
        self.shortcut = _u_shortcut(in_filters, nb_filters, subsample=init_subsample)

    def forward(self, x):
        """Return shortcut(x) merged with the conv branch."""
        return self.shortcut(x, self.residual(self.conv1(x)))
+
+
class _residual_block(nn.Module):
    """Stack of `repetitions` basic_blocks. The final repetition downsamples
    with stride 2 unless this is the first layer of the network."""

    def __init__(self, in_filters, nb_filters, repetitions, is_first_layer=False):
        super(_residual_block, self).__init__()
        blocks = []
        for idx in range(repetitions):
            # Only the last repetition strides, and never in the first layer.
            stride = 2 if (idx == repetitions - 1 and not is_first_layer) else 1
            # Only the first repetition sees the incoming channel count.
            src = in_filters if idx == 0 else nb_filters
            blocks.append(
                basic_block(in_filters=src, nb_filters=nb_filters, init_subsample=stride)
            )

        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        """Apply the block stack sequentially."""
        return self.model(x)
+
+
class _upsampling_residual_block(nn.Module):
    """Decoder stack: one _u_basic_block (which doubles the spatial size)
    followed by plain basic_blocks at the new resolution."""

    def __init__(self, in_filters, nb_filters, repetitions):
        super(_upsampling_residual_block, self).__init__()
        blocks = [
            _u_basic_block(in_filters=in_filters, nb_filters=nb_filters)
            if idx == 0
            else basic_block(in_filters=nb_filters, nb_filters=nb_filters)
            for idx in range(repetitions)
        ]
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        """Apply the block stack sequentially."""
        return self.model(x)
+
class res_skip(nn.Module):
    """Manga line extractor: residual encoder/decoder with skip merges.

    Encoder blocks 0-4 widen 1 -> 24 -> 48 -> 96 -> 192 -> 384 channels
    (each non-first block downsamples via its last stride-2 repetition);
    decoder blocks 5-8 upsample back, each merged with the mirrored encoder
    feature through a residual `_shortcut`; block9 + conv15 produce the
    single-channel line map.
    """

    def __init__(self):
        super(res_skip, self).__init__()
        # Encoder pyramid (channel widths double at each level).
        self.block0 = _residual_block(in_filters=1, nb_filters=24, repetitions=2, is_first_layer=True) # (input)
        self.block1 = _residual_block(in_filters=24, nb_filters=48, repetitions=3) # (block0)
        self.block2 = _residual_block(in_filters=48, nb_filters=96, repetitions=5) # (block1)
        self.block3 = _residual_block(in_filters=96, nb_filters=192, repetitions=7) # (block2)
        self.block4 = _residual_block(in_filters=192, nb_filters=384, repetitions=12) # (block3)

        # Decoder with mirrored repetition counts and skip shortcuts.
        self.block5 = _upsampling_residual_block(in_filters=384, nb_filters=192, repetitions=7) # (block4)
        self.res1 = _shortcut(in_filters=192, nb_filters=192) # (block3, block5, subsample=(1,1))

        self.block6 = _upsampling_residual_block(in_filters=192, nb_filters=96, repetitions=5) # (res1)
        self.res2 = _shortcut(in_filters=96, nb_filters=96) # (block2, block6, subsample=(1,1))

        self.block7 = _upsampling_residual_block(in_filters=96, nb_filters=48, repetitions=3) # (res2)
        self.res3 = _shortcut(in_filters=48, nb_filters=48) # (block1, block7, subsample=(1,1))

        self.block8 = _upsampling_residual_block(in_filters=48, nb_filters=24, repetitions=2) # (res3)
        self.res4 = _shortcut(in_filters=24, nb_filters=24) # (block0,block8, subsample=(1,1))

        # Output head: 1x1 conv down to a single line-intensity channel.
        self.block9 = _residual_block(in_filters=24, nb_filters=16, repetitions=2, is_first_layer=True) # (res4)
        self.conv15 = _bn_relu_conv(in_filters=16, nb_filters=1, fh=1, fw=1, subsample=1) # (block7)

    def forward(self, x):
        """Map a (b, 1, h, w) grayscale batch to a 1-channel line map."""
        # Encoder.
        x0 = self.block0(x)
        x1 = self.block1(x0)
        x2 = self.block2(x1)
        x3 = self.block3(x2)
        x4 = self.block4(x3)

        # Decoder with residual skip merges from the mirrored encoder level.
        x5 = self.block5(x4)
        res1 = self.res1(x3, x5)

        x6 = self.block6(res1)
        res2 = self.res2(x2, x6)

        x7 = self.block7(res2)
        res3 = self.res3(x1, x7)

        x8 = self.block8(res3)
        res4 = self.res4(x0, x8)

        # Output head.
        x9 = self.block9(res4)
        y = self.conv15(x9)

        return y

    def proceed(self, sketch):
        """Extract line art from a PIL image; returns a tensor in [-1, 1].

        Keeps only the first channel, rescales to the 0..255 range the
        network works in, then maps the output back to [-1, 1] and negates
        it. NOTE(review): requires CUDA (`.cuda()` below); the negation flips
        line polarity for downstream consumers — confirm the expected
        convention against callers.
        """
        sketch = transforms.ToTensor()(sketch).unsqueeze(0)[:, 0] * 255
        sketch = sketch.unsqueeze(1).cuda()
        sketch = self(sketch) / 127.5 - 1
        return -sketch.clamp(-1, 1)
\ No newline at end of file
diff --git a/preprocessor/sk_model.py b/preprocessor/sk_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..11cb01017d19f920973127e613e28d033b15f583
--- /dev/null
+++ b/preprocessor/sk_model.py
@@ -0,0 +1,94 @@
+import torch.nn as nn
+import torchvision.transforms.functional as tf
+
+
# Normalization constructor shared by ResidualBlock and Generator below.
norm_layer = nn.InstanceNorm2d
+
class ResidualBlock(nn.Module):
    """Residual block: two reflection-padded 3x3 convs with instance norm,
    added to the identity input."""

    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features),
        )

    def forward(self, x):
        """Return x plus the transformed branch (identity shortcut)."""
        return x + self.conv_block(x)
+
+
class Generator(nn.Module):
    """ResNet-style image-to-image generator.

    Architecture: 7x7 stem at full resolution, two stride-2 downsampling
    stages (64 -> 128 -> 256), `n_residual_blocks` residual blocks at the
    bottleneck, two transpose-conv upsampling stages (256 -> 128 -> 64),
    and a 7x7 output head with an optional Sigmoid.
    """

    def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
        super(Generator, self).__init__()

        # Stem: reflection-padded 7x7 conv at full resolution.
        self.model0 = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 64, 7),
            norm_layer(64),
            nn.ReLU(inplace=True),
        )

        # Two stride-2 downsampling stages.
        down = []
        features = 64
        for _ in range(2):
            down += [
                nn.Conv2d(features, features * 2, 3, stride=2, padding=1),
                norm_layer(features * 2),
                nn.ReLU(inplace=True),
            ]
            features *= 2
        self.model1 = nn.Sequential(*down)

        # Residual trunk at the bottleneck resolution.
        self.model2 = nn.Sequential(
            *[ResidualBlock(features) for _ in range(n_residual_blocks)]
        )

        # Two transpose-conv upsampling stages.
        up = []
        for _ in range(2):
            up += [
                nn.ConvTranspose2d(features, features // 2, 3, stride=2,
                                   padding=1, output_padding=1),
                norm_layer(features // 2),
                nn.ReLU(inplace=True),
            ]
            features //= 2
        self.model3 = nn.Sequential(*up)

        # Output head: 7x7 conv back to the requested channel count.
        head = [nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7)]
        if sigmoid:
            head.append(nn.Sigmoid())
        self.model4 = nn.Sequential(*head)

    def forward(self, x, cond=None):
        """Run stem, downsampling, residual trunk, upsampling, and head.

        `cond` is accepted for interface compatibility but is unused.
        """
        h = self.model0(x)
        h = self.model1(h)
        h = self.model2(h)
        h = self.model3(h)
        return self.model4(h)
+
+
class LineartDetector(nn.Module):
    """Wrapper around a small ResNet Generator that extracts line art from
    an RGB image."""

    def __init__(self):
        super().__init__()
        # 3-channel input, 1-channel line map, 3 residual blocks.
        self.model = Generator(3, 1, 3)

    def load_state_dict(self, sd, strict=True):
        """Load a checkpoint saved for the bare Generator.

        Checkpoints for this detector target ``self.model`` directly, so this
        override forwards to it instead of ``nn.Module``'s default. Unlike the
        previous version, ``strict`` is forwarded and the usual
        ``IncompatibleKeys`` result is returned, preserving the
        ``nn.Module.load_state_dict`` contract for callers.
        """
        return self.model.load_state_dict(sd, strict=strict)

    def proceed(self, sketch):
        """Convert a PIL image to a negated line-art tensor.

        NOTE(review): requires CUDA; the input stays in the 0..255 range
        (``pil_to_tensor`` does not normalize) — presumably the pretrained
        weights expect that scale. TODO confirm against the training code.
        """
        sketch = tf.pil_to_tensor(sketch).unsqueeze(0).cuda().float()
        return -self.model(sketch)
\ No newline at end of file
diff --git a/preprocessor/sketchKeras.py b/preprocessor/sketchKeras.py
new file mode 100644
index 0000000000000000000000000000000000000000..e435f7a636c1a080e809fb0a2f3ffb48682afe1b
--- /dev/null
+++ b/preprocessor/sketchKeras.py
@@ -0,0 +1,153 @@
+import cv2
+import numpy as np
+
+import torch
+import torch.nn as nn
+
+
def postprocess(pred, thresh=0.18):
    """Collapse a multi-channel prediction into one sketch map in [-1, 1].

    Takes the element-wise maximum over dim 0, zeroes responses below
    ``thresh``, then rescales [0, 1] -> [-1, 1] (suppressed pixels map to -1).
    The input tensor is not modified.

    Args:
        pred: tensor whose dim 0 is reduced with ``torch.amax``.
        thresh: suppression threshold; must lie in [0, 1].

    Returns:
        The reduced, thresholded tensor rescaled to [-1, 1].

    Raises:
        ValueError: if ``thresh`` is outside [0, 1].
    """
    # `assert` is stripped under `python -O`; validate input explicitly.
    if not 0.0 <= thresh <= 1.0:
        raise ValueError(f"thresh must be in [0, 1], got {thresh}")

    pred = torch.amax(pred, 0)
    pred = torch.where(pred < thresh, torch.zeros_like(pred), pred)
    # Map [0, 1] -> [-1, 1].
    return (pred - 0.5) * 2
+
+
class SketchKeras(nn.Module):
    """PyTorch port of the sketchKeras line-extraction U-Net.

    Encoder (`downblock_1..6`) narrows 1 -> 512 channels over four stride-2
    stages; decoder (`upblock_1..4`) upsamples back with bicubic Upsample and
    concatenated encoder skips; `last_conv` emits a 1-channel line map.
    """

    def __init__(self):
        super(SketchKeras, self).__init__()

        # Encoder. Stride-2 convs with kernel 4 halve the resolution;
        # `momentum=0` freezes the BatchNorm running statistics at their
        # loaded values (matching the converted Keras weights).
        self.downblock_1 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(1, 32, kernel_size=3, stride=1),
            nn.BatchNorm2d(32, eps=1e-3, momentum=0),
            nn.ReLU(),
        )
        self.downblock_2 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.BatchNorm2d(64, eps=1e-3, momentum=0),
            nn.ReLU(),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.BatchNorm2d(64, eps=1e-3, momentum=0),
            nn.ReLU(),
        )
        self.downblock_3 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(64, 128, kernel_size=4, stride=2),
            nn.BatchNorm2d(128, eps=1e-3, momentum=0),
            nn.ReLU(),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(128, 128, kernel_size=3, stride=1),
            nn.BatchNorm2d(128, eps=1e-3, momentum=0),
            nn.ReLU(),
        )
        self.downblock_4 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(128, 256, kernel_size=4, stride=2),
            nn.BatchNorm2d(256, eps=1e-3, momentum=0),
            nn.ReLU(),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(256, 256, kernel_size=3, stride=1),
            nn.BatchNorm2d(256, eps=1e-3, momentum=0),
            nn.ReLU(),
        )
        self.downblock_5 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(256, 512, kernel_size=4, stride=2),
            nn.BatchNorm2d(512, eps=1e-3, momentum=0),
            nn.ReLU(),
        )
        self.downblock_6 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(512, 512, kernel_size=3, stride=1),
            nn.BatchNorm2d(512, eps=1e-3, momentum=0),
            nn.ReLU(),
        )

        # Decoder. Each upblock takes [skip, deeper] concatenated along the
        # channel dim (hence the doubled input channels); asymmetric padding
        # (1, 2, 1, 2) compensates for the even (4x4) kernel.
        self.upblock_1 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bicubic"),
            nn.ReflectionPad2d((1, 2, 1, 2)),
            nn.Conv2d(1024, 512, kernel_size=4, stride=1),
            nn.BatchNorm2d(512, eps=1e-3, momentum=0),
            nn.ReLU(),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(512, 256, kernel_size=3, stride=1),
            nn.BatchNorm2d(256, eps=1e-3, momentum=0),
            nn.ReLU(),
        )

        self.upblock_2 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bicubic"),
            nn.ReflectionPad2d((1, 2, 1, 2)),
            nn.Conv2d(512, 256, kernel_size=4, stride=1),
            nn.BatchNorm2d(256, eps=1e-3, momentum=0),
            nn.ReLU(),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(256, 128, kernel_size=3, stride=1),
            nn.BatchNorm2d(128, eps=1e-3, momentum=0),
            nn.ReLU(),
        )

        self.upblock_3 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bicubic"),
            nn.ReflectionPad2d((1, 2, 1, 2)),
            nn.Conv2d(256, 128, kernel_size=4, stride=1),
            nn.BatchNorm2d(128, eps=1e-3, momentum=0),
            nn.ReLU(),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(128, 64, kernel_size=3, stride=1),
            nn.BatchNorm2d(64, eps=1e-3, momentum=0),
            nn.ReLU(),
        )

        self.upblock_4 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bicubic"),
            nn.ReflectionPad2d((1, 2, 1, 2)),
            nn.Conv2d(128, 64, kernel_size=4, stride=1),
            nn.BatchNorm2d(64, eps=1e-3, momentum=0),
            nn.ReLU(),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(64, 32, kernel_size=3, stride=1),
            nn.BatchNorm2d(32, eps=1e-3, momentum=0),
            nn.ReLU(),
        )

        # Output head over the final [downblock_1, upblock_4] concatenation.
        self.last_pad = nn.ReflectionPad2d((1, 1, 1, 1))
        self.last_conv = nn.Conv2d(64, 1, kernel_size=3, stride=1)

    def forward(self, x):
        """U-Net pass: encode, then decode with skip concatenations."""
        d1 = self.downblock_1(x)
        d2 = self.downblock_2(d1)
        d3 = self.downblock_3(d2)
        d4 = self.downblock_4(d3)
        d5 = self.downblock_5(d4)
        d6 = self.downblock_6(d5)

        u1 = torch.cat((d5, d6), dim=1)
        u1 = self.upblock_1(u1)
        u2 = torch.cat((d4, u1), dim=1)
        u2 = self.upblock_2(u2)
        u3 = torch.cat((d3, u2), dim=1)
        u3 = self.upblock_3(u3)
        u4 = torch.cat((d2, u3), dim=1)
        u4 = self.upblock_4(u4)
        u5 = torch.cat((d1, u4), dim=1)

        out = self.last_conv(self.last_pad(u5))

        return out

    def proceed(self, img):
        """Extract a sketch from an image via a high-pass + network pass.

        Builds a high-pass residual (image minus Gaussian blur), normalizes
        it, and feeds each color channel as a separate batch item (the
        permute below turns (1, H, W, C) into (C, 1, H, W)); `postprocess`
        then folds the per-channel outputs back together with a max.
        NOTE(review): requires CUDA; `img /= np.max(img)` divides by zero for
        a perfectly flat input, and assumes a channel-last (H, W, C) input —
        TODO confirm both against callers.
        """
        img = np.array(img)
        blurred = cv2.GaussianBlur(img, (0, 0), 3)
        img = img.astype(int) - blurred.astype(int)
        img = img.astype(np.float32) / 127.5
        img /= np.max(img)
        img = torch.tensor(img).unsqueeze(0).permute(3, 0, 1, 2).cuda()
        img = self(img)
        img = postprocess(img, thresh=0.1).unsqueeze(1)
        return img
\ No newline at end of file
diff --git a/refnet/__init__.py b/refnet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/refnet/ldm/__init__.py b/refnet/ldm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c3b5e44e026061c2566a44553e2ce92adf2e60b
--- /dev/null
+++ b/refnet/ldm/__init__.py
@@ -0,0 +1 @@
+from .ddpm import LatentDiffusion
\ No newline at end of file
diff --git a/refnet/ldm/ddpm.py b/refnet/ldm/ddpm.py
new file mode 100644
index 0000000000000000000000000000000000000000..10ac55e062d5d76e53dc8d992febd7d271d63a0f
--- /dev/null
+++ b/refnet/ldm/ddpm.py
@@ -0,0 +1,236 @@
+"""
+wild mixture of
+https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
+https://github.com/CompVis/taming-transformers
+-- merci
+"""
+
+import torch
+import torch.nn as nn
+import numpy as np
+from contextlib import contextmanager
+from functools import partial
+
+from refnet.util import default, count_params, instantiate_from_config, exists
+from refnet.ldm.util import make_beta_schedule, extract_into_tensor
+
+
+
def disabled_train(self, mode=True):
    """Replacement for ``nn.Module.train`` that ignores mode changes.

    Bind this over ``model.train`` so a frozen submodule can never be
    flipped back into training mode; the requested ``mode`` is discarded
    and the module is returned unchanged.
    """
    return self
+
+
def uniform_on_device(r1, r2, shape, device):
    """Sample a tensor of ``shape`` uniformly from the interval between
    ``r2`` and ``r1``, allocated directly on ``device``."""
    span = r1 - r2
    return span * torch.rand(*shape, device=device) + r2
+
+
def rescale_zero_terminal_snr(betas):
    """Rescale a beta schedule so the terminal timestep has zero SNR.

    Implements Algorithm 1 of https://arxiv.org/pdf/2305.08891.pdf: the
    sqrt(alpha_bar) curve is shifted so it ends at exactly zero and rescaled
    so its first value is preserved, then converted back to betas.

    Args:
        betas (`torch.FloatTensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.FloatTensor`: rescaled betas with zero terminal SNR
    """
    # sqrt of the cumulative alphas.
    ab_sqrt = torch.cumprod(1.0 - betas, dim=0).sqrt()

    # Remember the endpoints before modifying the curve.
    first = ab_sqrt[0].clone()
    last = ab_sqrt[-1].clone()

    # Shift so the curve ends at zero, then rescale so it still starts at
    # its original value.
    ab_sqrt -= last
    ab_sqrt *= first / (first - last)

    # Undo the sqrt and the cumulative product to recover per-step alphas.
    ab = ab_sqrt ** 2
    step_alphas = ab[1:] / ab[:-1]
    step_alphas = torch.cat([ab[0:1], step_alphas])

    return 1 - step_alphas
+
+
class DDPM(nn.Module):
    # classic DDPM with Gaussian diffusion, in image space
    """Base denoising-diffusion model holding the U-Net and noise schedule.

    Registers all schedule quantities (alphas, posterior coefficients, …)
    as buffers so they move with the module's device/dtype.
    """

    def __init__(
        self,
        unet_config,
        timesteps = 1000,
        beta_schedule = "scaled_linear",
        image_size = 256,
        channels = 3,
        linear_start = 1e-4,
        linear_end = 2e-2,
        cosine_s = 8e-3,
        v_posterior = 0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
        parameterization = "eps", # all assuming fixed variance schedules
        zero_snr = False,
        half_precision_dtype = "float16",
        version = "sdv1",
        *args,
        **kwargs
    ):
        """Build the wrapped U-Net and register the diffusion schedule.

        Args:
            unet_config: config dict consumed by `DiffusionWrapper` /
                `instantiate_from_config` to build the denoiser.
            timesteps: number of diffusion steps in the schedule.
            beta_schedule: schedule name passed to `make_beta_schedule`.
            image_size, channels: bookkeeping only in this class (stored,
                not used by the methods visible here).
            linear_start, linear_end, cosine_s: schedule hyperparameters.
            v_posterior: interpolation weight for the posterior variance.
            parameterization: "eps" (noise prediction) or "v" (v-prediction).
            zero_snr: rescale betas for zero terminal SNR (requires "v").
            half_precision_dtype: "float16" or "bfloat16" for mixed precision.
            version: "sdxl" toggles `self.is_sdxl`; anything else is treated
                as an SD v1-style model.
        """
        super().__init__()
        assert parameterization in ["eps", "v"], "currently only supporting 'eps' and 'v'"
        assert half_precision_dtype in ["float16", "bfloat16"], "K-diffusion samplers do not support bfloat16, use float16 by default"
        if zero_snr:
            assert parameterization == "v", 'Zero SNR is only available for "v-prediction" model.'

        self.is_sdxl = (version == "sdxl")
        self.parameterization = parameterization
        print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
        self.cond_stage_model = None
        self.img_embedder = None
        self.image_size = image_size # try conv?
        self.channels = channels
        self.model = DiffusionWrapper(unet_config)
        count_params(self.model, verbose=True)
        self.v_posterior = v_posterior
        self.half_precision_dtype = torch.bfloat16 if half_precision_dtype == "bfloat16" else torch.float16
        self.register_schedule(beta_schedule=beta_schedule, timesteps=timesteps,
                               linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s, zero_snr=zero_snr)


    def register_schedule(self, beta_schedule="scaled_linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, zero_snr=False):
        """Compute and register all schedule-derived buffers.

        NOTE(review): the numpy ops below assume `make_beta_schedule`
        returns a numpy array — verify against its implementation.
        """
        betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
                                   cosine_s=cosine_s, zero_snr=zero_snr)

        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        # Shifted-by-one cumulative alphas (alpha_bar_{t-1}), with 1 prepended.
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
                    1. - alphas_cumprod) + self.v_posterior * betas
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer('posterior_variance', to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
        self.register_buffer('posterior_mean_coef1', to_torch(
            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
        self.register_buffer('posterior_mean_coef2', to_torch(
            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))


    @contextmanager
    def ema_scope(self, context=None):
        """Temporarily swap EMA weights into `self.model`.

        NOTE(review): `self.use_ema` and `self.model_ema` are never assigned
        in this class — presumably set by a subclass or training wrapper;
        calling this without them would raise AttributeError. Verify before
        use.
        """
        if self.use_ema:
            self.model_ema.store(self.model.parameters())
            self.model_ema.copy_to(self.model)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.model.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")


    def predict_start_from_z_and_v(self, x_t, t, v):
        """Recover x_0 from a noisy latent x_t and a v-prediction:
        x_0 = sqrt(alpha_bar_t) * x_t - sqrt(1 - alpha_bar_t) * v."""
        # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        return (
                extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
        )

    def add_noise(self, x_start, t, noise=None):
        """Forward-diffuse x_start to timestep t (q-sample); `noise` defaults
        to a fresh standard normal sample."""
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise).to(x_start.dtype)

    def get_v(self, x, noise, t):
        """Compute the v-parameterization target:
        v = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x."""
        return (
                extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
        )

    def normalize_timesteps(self, timesteps):
        """Identity hook; subclasses may remap timestep values."""
        return timesteps
+
+
class LatentDiffusion(DDPM):
    """main class

    DDPM operating in the latent space of a frozen first-stage autoencoder,
    with a frozen conditioning encoder.
    """

    def __init__(
            self,
            first_stage_config,
            cond_stage_config,
            scale_factor = 1.0,
            *args,
            **kwargs
    ):
        """Instantiate the frozen first-stage VAE and conditioning encoder.

        Args:
            first_stage_config: config for the latent autoencoder.
            cond_stage_config: config for the conditioning encoder.
            scale_factor: multiplier applied to encoded latents (and undone
                on decode).
        """
        super().__init__(*args, **kwargs)
        self.scale_factor = scale_factor
        # Both stages are frozen: eval mode and no gradients.
        self.first_stage_model, self.cond_stage_model = map(
            lambda t: instantiate_from_config(t).eval().requires_grad_(False),
            (first_stage_config, cond_stage_config)
        )

    @torch.no_grad()
    def get_first_stage_encoding(self, x):
        """Encode an image to a scaled latent sample.

        NOTE(review): `self.dtype` is not defined in this class or in
        `nn.Module` — presumably set elsewhere (e.g. a mixin or assignment
        in the training wrapper); verify.
        """
        encoder_posterior = self.first_stage_model.encode(x)
        z = encoder_posterior.sample() * self.scale_factor
        return z.to(self.dtype).detach()

    @torch.no_grad()
    def decode_first_stage(self, z):
        """Decode a latent back to image space, undoing the scale factor."""
        z = 1. / self.scale_factor * z
        return self.first_stage_model.decode(z.to(self.first_stage_model.dtype)).detach()

    def apply_model(self, x_noisy, t, cond):
        """Run the wrapped U-Net on a noisy latent with conditioning `cond`
        (a dict expanded to keyword arguments)."""
        return self.model(x_noisy, t, **cond)

    def get_learned_embedding(self, c, *args, **kwargs):
        """Return detached (wd_emb, wd_logits, clip_emb) embeddings of `c`
        from the image embedder and the conditioning (CLIP) encoder."""
        wd_emb, wd_logits = map(lambda t: t.detach() if exists(t) else None, self.img_embedder.encode(c, **kwargs))
        clip_emb = self.cond_stage_model.encode(c, **kwargs).detach()
        return wd_emb, wd_logits, clip_emb
+
+
class DiffusionWrapper(nn.Module):
    """Thin wrapper that builds the U-Net from a config and concatenates
    list-valued conditioning entries before forwarding them."""

    def __init__(self, diff_model_config):
        super().__init__()
        self.diffusion_model = instantiate_from_config(diff_model_config)

    def forward(self, x, t, **cond):
        """Concatenate multi-part conditionings along dim 1, then run the
        U-Net with them as keyword arguments."""
        for key in cond:
            if key in ("context", "y", "concat"):
                cond[key] = torch.cat(cond[key], 1)
        return self.diffusion_model(x, t, **cond)
\ No newline at end of file
diff --git a/refnet/ldm/openaimodel.py b/refnet/ldm/openaimodel.py
new file mode 100644
index 0000000000000000000000000000000000000000..05a666c99ee5f7c5654ae7713a48e42b8dbec79f
--- /dev/null
+++ b/refnet/ldm/openaimodel.py
@@ -0,0 +1,386 @@
+from abc import abstractmethod
+import math
+
+import numpy as np
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+
+from refnet.ldm.util import (
+ conv_nd,
+ linear,
+ avg_pool_nd,
+ zero_module,
+ normalization,
+ timestep_embedding,
+)
+from refnet.util import checkpoint_wrapper
+
+
+
+# dummy replace
def convert_module_to_f16(x):
    """No-op placeholder for fp16 module conversion (kept for API compatibility)."""
    pass

def convert_module_to_f32(x):
    """No-op placeholder for fp32 module conversion (kept for API compatibility)."""
    pass
+
+
+## go
class AttentionPool2d(nn.Module):
    """
    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py

    Pools an N x C x (spatial) feature map to a single N x C' vector via
    attention: the spatial mean is prepended as an extra token, and its
    attention output is returned as the pooled representation.
    """

    def __init__(
        self,
        spacial_dim: int,
        embed_dim: int,
        num_heads_channels: int,
        output_dim: int = None,
    ):
        super().__init__()
        # One positional embedding per spatial location plus the mean token.
        self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        b, c, *_spatial = x.shape
        x = x.reshape(b, c, -1)  # NC(HW)
        # Prepend the spatial mean as an extra token.
        x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)
        x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)
        x = self.qkv_proj(x)
        x = self.attention(x)
        x = self.c_proj(x)
        # Return only the mean-token output as the pooled vector.
        return x[:, :, 0]
+
+
class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """
+
+
class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims != 3:
            # 1D/2D: double every spatial dimension.
            upsampled = F.interpolate(x, scale_factor=2, mode="nearest")
        else:
            # 3D: keep depth, double only the inner two dimensions.
            depth, height, width = x.shape[2], x.shape[3], x.shape[4]
            upsampled = F.interpolate(x, (depth, height * 2, width * 2), mode="nearest")
        return self.conv(upsampled) if self.use_conv else upsampled
+
class TransposedUpsample(nn.Module):
    """Learned 2x upsampling without padding (stride-2 transposed convolution)."""

    def __init__(self, channels, out_channels=None, ks=5):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.up = nn.ConvTranspose2d(
            self.channels, self.out_channels, kernel_size=ks, stride=2
        )

    def forward(self, x):
        return self.up(x)
+
+
class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        # 3D inputs keep their depth dimension; only the inner two shrink.
        stride = (1, 2, 2) if dims == 3 else 2
        if use_conv:
            self.op = conv_nd(
                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
            )
        else:
            # Average pooling cannot change the channel count.
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
+
+
class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        # Optional resampling applied to both the residual branch (h_upd)
        # and the skip path (x_upd) so their resolutions match.
        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        # Projects the timestep embedding; width is doubled when scale-shift
        # norm needs separate scale and shift halves.
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            # Zero-initialized so the residual branch starts near identity.
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        # Skip connection: identity when channels match, otherwise a 3x3 or
        # 1x1 projection depending on use_conv.
        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    @checkpoint_wrapper
    def forward(self, x, emb):
        """
        Apply the block to `x`, conditioned on timestep embedding `emb`.
        :param x: an [N x C x ...] feature tensor.
        :param emb: an [N x emb_channels] embedding tensor.
        :return: an [N x out_channels x ...] tensor.
        """
        if self.updown:
            # Resample between norm/act and the conv so the conv sees the
            # resampled features, and resample x for the skip path too.
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        # Broadcast the embedding over all spatial dimensions.
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            # FiLM-style conditioning: h = norm(h) * (1 + scale) + shift,
            # applied between the norm and the rest of the out stack.
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h
+
+
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.
    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        # Either a fixed head count, or derive it from a per-head width.
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        if use_new_attention_order:
            # split qkv before split heads
            self.attention = QKVAttention(self.num_heads)
        else:
            # split heads before split qkv
            self.attention = QKVAttentionLegacy(self.num_heads)

        # Zero-initialized projection so the block starts near identity.
        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    @checkpoint_wrapper
    def forward(self, x):
        # Flatten all spatial dims into one token axis, attend, and restore
        # the original shape; residual connection around the attention.
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)
+
+
def count_flops_attn(model, _x, y):
    """
    A counter for the `thop` package to count the operations in an
    attention operation.
    Meant to be used like:
        macs, params = thop.profile(
            model,
            inputs=(inputs, timestamps),
            custom_ops={QKVAttention: QKVAttention.count_flops},
        )
    """
    b, c, *spatial = y[0].shape
    num_spatial = int(np.prod(spatial))
    # Two matmuls of identical cost: QK^T forms the attention weights, then
    # weights @ V combines the value vectors.
    model.total_ops += th.DoubleTensor([2 * b * num_spatial ** 2 * c])
+
+
class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        # Legacy layout: split heads first, then q/k/v within each head.
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        # Scale q and k each by ch^-1/4 (ch^-1/2 overall): more stable in
        # f16 than dividing the logits afterwards.
        scale = 1 / math.sqrt(math.sqrt(ch))
        logits = th.einsum("bct,bcs->bts", q * scale, k * scale)
        weight = th.softmax(logits.float(), dim=-1).type(logits.dtype)
        out = th.einsum("bts,bcs->bct", weight, v)
        return out.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
+
+
class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        # New layout: split q/k/v first, then heads within each.
        q, k, v = qkv.chunk(3, dim=1)
        # ch^-1/4 applied to both q and k (ch^-1/2 overall) is more f16-stable
        # than dividing the logits afterwards.
        scale = 1 / math.sqrt(math.sqrt(ch))
        q_heads = (q * scale).view(bs * self.n_heads, ch, length)
        k_heads = (k * scale).view(bs * self.n_heads, ch, length)
        logits = th.einsum("bct,bcs->bts", q_heads, k_heads)
        weight = th.softmax(logits.float(), dim=-1).type(logits.dtype)
        out = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
        return out.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
+
+
class Timestep(nn.Module):
    """Module wrapper producing sinusoidal timestep embeddings of width `dim`."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim  # output embedding dimension

    def forward(self, t):
        return timestep_embedding(t, self.dim)
\ No newline at end of file
diff --git a/refnet/ldm/util.py b/refnet/ldm/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..550f84fe6c19389c6a19754c35074134c512d071
--- /dev/null
+++ b/refnet/ldm/util.py
@@ -0,0 +1,289 @@
+# adopted from
+# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+# and
+# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+# and
+# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
+#
+# thanks!
+
+
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import repeat
+
+
def rescale_zero_terminal_snr(betas):
    """
    Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)


    Args:
        betas (`torch.FloatTensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.FloatTensor`: rescaled betas with zero terminal SNR
    """
    # Work in sqrt(alpha_bar) space, where the algorithm is a simple
    # shift-and-scale.
    sqrt_abar = torch.cumprod(1.0 - betas, dim=0).sqrt()
    first = sqrt_abar[0].clone()
    last = sqrt_abar[-1].clone()

    # Shift so the final step hits exactly zero, then rescale so the first
    # step keeps its original value.
    sqrt_abar = (sqrt_abar - last) * (first / (first - last))

    # Invert sqrt and cumprod to recover per-step alphas, then betas.
    abar = sqrt_abar ** 2
    alphas = torch.cat([abar[0:1], abar[1:] / abar[:-1]])
    return 1 - alphas
+
+
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, zero_snr=False):
    """Build a beta (noise-variance) schedule of length `n_timestep`.

    :param schedule: one of "linear", "scaled_linear", "cosine",
        "squaredcos_cap_v2", "sqrt_linear", "sqrt".
    :param linear_start: start value for the linear-family schedules.
    :param linear_end: end value for the linear-family schedules.
    :param cosine_s: small offset for the cosine schedule.
    :param zero_snr: if True, rescale so the terminal SNR is exactly zero.
    :return: numpy array of betas (float64).
    :raises ValueError: for an unknown schedule name.
    """
    if schedule == "linear":
        betas = (
            torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
        )
    elif schedule == "scaled_linear":
        # Linear in sqrt(beta) space.
        betas = (
            torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
        )

    elif schedule == "cosine":
        # Betas derived from the squared-cosine alpha_bar curve, capped at
        # 0.999 to avoid a singular final step.
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        betas = np.clip(betas, a_min=0, a_max=0.999)

    elif schedule == "squaredcos_cap_v2":  # used for karlo prior
        # return early
        # NOTE(review): this branch returns a numpy array directly and
        # therefore skips the zero_snr rescaling below — confirm intended.
        return betas_for_alpha_bar(
            n_timestep,
            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
        )

    elif schedule == "sqrt_linear":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
    elif schedule == "sqrt":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")

    if zero_snr:
        betas = rescale_zero_terminal_snr(betas)
    return betas.numpy()
+
+
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
    """Select a DDIM timestep subsequence from the full DDPM schedule.

    :param ddim_discr_method: 'uniform' (evenly strided) or 'quad'
        (quadratic spacing, denser near t=0).
    :raises NotImplementedError: for an unknown discretization method.
    """
    if ddim_discr_method == 'uniform':
        stride = num_ddpm_timesteps // num_ddim_timesteps
        ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, stride)))
    elif ddim_discr_method == 'quad':
        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
    else:
        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')

    # add one to get the final alpha values right (the ones from first scale to data during sampling)
    steps_out = ddim_timesteps + 1
    if verbose:
        print(f'Selected timesteps for ddim sampler: {steps_out}')
    return steps_out
+
+
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
    """Compute (sigmas, alphas, alphas_prev) for the selected DDIM steps.

    Sigma follows the formula of https://arxiv.org/abs/2010.02502, scaled by eta.
    """
    selected = alphacums[ddim_timesteps]
    # Previous-step alpha_bar; the first step falls back to alphacums[0].
    selected_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())

    variance_ratio = (1 - selected_prev) / (1 - selected) * (1 - selected / selected_prev)
    sigmas = eta * np.sqrt(variance_ratio)
    if verbose:
        print(f'Selected alphas for ddim sampler: a_t: {selected}; a_(t-1): {selected_prev}')
        print(f'For the chosen value of eta, which is {eta}, '
              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
    return sigmas, selected, selected_prev
+
+
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].
    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    n = num_diffusion_timesteps
    # beta_i = 1 - abar(t_{i+1}) / abar(t_i), capped at max_beta.
    return np.array([
        min(1 - alpha_bar((i + 1) / n) / alpha_bar(i / n), max_beta)
        for i in range(n)
    ])
+
+
def extract_into_tensor(a, t, x_shape):
    """Gather a[t] per batch element and reshape to (b, 1, ..., 1) so it
    broadcasts against a tensor of shape `x_shape`."""
    batch = t.shape[0]
    gathered = a.gather(-1, t)
    return gathered.reshape(batch, *((1,) * (len(x_shape) - 1)))
+
+
+
class CheckpointFunction(torch.autograd.Function):
    """Gradient checkpointing: run `run_function` without saving activations
    in forward, then recompute it under grad in backward."""

    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        # The first `length` args are input tensors to differentiate through;
        # the rest are parameters that also need gradients.
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        # Capture the autocast state so the backward recomputation matches
        # the forward numerics.
        ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
                                   "dtype": torch.get_autocast_gpu_dtype(),
                                   "cache_enabled": torch.is_autocast_cache_enabled()}
        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad(), \
                torch.amp.autocast("cuda", **ctx.gpu_autocast_kwargs):
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
            input_grads = torch.autograd.grad(
                output_tensors,
                ctx.input_tensors + ctx.input_params,
                output_grads,
                allow_unused=True,
            )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        # No gradients for (run_function, length) themselves.
        return (None, None) + input_grads
+
+
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: if True, skip the sinusoids and just tile the raw
                        timestep value across `dim` channels.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if repeat_only:
        return repeat(timesteps, 'b -> b d', d=dim)

    half = dim // 2
    # Geometric frequency ladder from 1 down toward 1/max_period.
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Odd dim: pad one zero channel to reach exactly `dim`.
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding
+
+
def zero_module(module):
    """Reset every parameter of `module` to zero in place and return the module."""
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
+
+
def scale_module(module, scale):
    """Multiply every parameter of `module` by `scale` in place and return the module."""
    with torch.no_grad():
        for param in module.parameters():
            param.mul_(scale)
    return module
+
+
def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    reduce_dims = tuple(range(1, tensor.dim()))
    return tensor.mean(dim=reduce_dims)
+
+
def normalization(channels):
    """
    Make a standard normalization layer.
    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    # 32-group GroupNorm computed in the parameter dtype (see GroupNorm32).
    return GroupNorm32(32, channels)
+
+
+# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
    """Sigmoid-weighted linear unit: x * sigmoid(x)."""

    def forward(self, x):
        return torch.sigmoid(x).mul(x)
+
+
class GroupNorm32(nn.GroupNorm):
    """GroupNorm that normalizes in the weight's dtype, then casts back to the input dtype."""

    def forward(self, x):
        normalized = super().forward(x.to(self.weight.dtype))
        return normalized.type(x.dtype)
+
+
def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    conv_cls = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}.get(dims)
    if conv_cls is None:
        raise ValueError(f"unsupported dimensions: {dims}")
    return conv_cls(*args, **kwargs)
+
+
def linear(*args, **kwargs):
    """
    Create a linear module.
    """
    # Thin alias around nn.Linear, kept for parity with conv_nd/avg_pool_nd.
    return nn.Linear(*args, **kwargs)
+
+
def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    pool_cls = {1: nn.AvgPool1d, 2: nn.AvgPool2d, 3: nn.AvgPool3d}.get(dims)
    if pool_cls is None:
        raise ValueError(f"unsupported dimensions: {dims}")
    return pool_cls(*args, **kwargs)
+
+
def noise_like(shape, device, repeat=False):
    """Sample standard Gaussian noise of `shape`; if `repeat`, draw a single
    sample and tile it across the batch dimension."""
    if repeat:
        single = torch.randn((1, *shape[1:]), device=device)
        return single.repeat(shape[0], *((1,) * (len(shape) - 1)))
    return torch.randn(shape, device=device)
diff --git a/refnet/modules/__init__.py b/refnet/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6bc7b110c367b580747e9b1d0748ced2dd2dcfb9
--- /dev/null
+++ b/refnet/modules/__init__.py
@@ -0,0 +1,34 @@
+from collections import namedtuple
+
+
def wd_v14_swin2_tagger_config():
    """Return a timm-style config record for the WD v1.4 SwinV2 tagger."""
    values = dict(
        architecture="swinv2_base_window8_256",
        num_classes=9083,
        num_features=1024,
        global_pool="avg",
        model_args={
            "act_layer": "gelu",
            "img_size": 448,
            "window_size": 14,
        },
        pretrained_cfg={
            "custom_load": False,
            "input_size": [3, 448, 448],
            "fixed_input_size": False,
            "interpolation": "bicubic",
            "crop_pct": 1.0,
            "crop_mode": "center",
            "mean": [0.5, 0.5, 0.5],
            "std": [0.5, 0.5, 0.5],
            "num_classes": 9083,
            "pool_size": None,
            "first_conv": None,
            "classifier": None,
        },
    )
    # Field order follows insertion order of `values`.
    CustomConfig = namedtuple("CustomConfig", list(values))
    return CustomConfig(**values)
\ No newline at end of file
diff --git a/refnet/modules/attention.py b/refnet/modules/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c2bb322e76743ae100d0d3dd85657b6d1ee2a2a
--- /dev/null
+++ b/refnet/modules/attention.py
@@ -0,0 +1,309 @@
+from calendar import c
+import torch.nn as nn
+
+from einops import rearrange
+from refnet.util import exists, default, checkpoint_wrapper
+from .layers import RMSNorm
+from .attn_utils import *
+
+
def create_masked_attention_bias(
    mask: torch.Tensor,
    threshold: float,
    num_heads: int,
    context_len: int
):
    """Build a per-head additive attention bias that routes foreground queries
    (mask > threshold) to the first half of the context and background queries
    to the second half.

    :param mask: (b, seq_len, 1) per-query mask values.
    :return: (b, num_heads, seq_len, padded_context_len) bias of 0 / -inf.
    """
    b, seq_len, _ = mask.shape
    half_len = context_len // 2

    # Pad the context axis up to a multiple of 8 — presumably an attention
    # kernel alignment requirement; confirm against the backend in use.
    padded_len = ((context_len + 7) // 8) * 8 if context_len % 8 else context_len

    fg_bias = mask.new_zeros(b, seq_len, padded_len)
    bg_bias = mask.new_zeros(b, seq_len, padded_len)
    fg_bias[:, :, half_len:] = -float('inf')
    bg_bias[:, :, :half_len] = -float('inf')

    # Broadcast the (b, seq, 1) mask over the context axis to pick a row type.
    routed = torch.where(mask > threshold, fg_bias, bg_bias)
    return routed.unsqueeze(1).repeat_interleave(num_heads, dim=1)
+
class Identity(nn.Module):
    """Pass-through module: returns its first argument, ignoring all others."""

    def forward(self, x, *args, **kwargs):
        return x
+
+
+# Rotary Positional Embeddings implementation
# Rotary Positional Embeddings implementation
class RotaryPositionalEmbeddings(nn.Module):
    """2-D rotary positional embeddings (RoPE) over an h x w token grid.

    Half of each head's channel pairs are rotated by row-index phases and the
    other half by column-index phases.
    """

    def __init__(self, dim, max_seq_len=1024, theta=10000.0):
        super().__init__()
        assert dim % 2 == 0, "Dimension must be divisible by 2"
        # Split channels between the two spatial axes.
        dim = dim // 2
        self.max_seq_len = max_seq_len
        freqs = torch.outer(
            torch.arange(self.max_seq_len),
            1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim))
        )
        # Unit-modulus complex rotations e^{i * pos * freq}.
        freqs = torch.polar(torch.ones_like(freqs), freqs)
        # NOTE(review): the same table is registered for both axes; that is
        # fine as long as both h and w stay <= max_seq_len.
        self.register_buffer("freq_h", freqs, persistent=False)
        self.register_buffer("freq_w", freqs, persistent=False)

    def forward(self, x, grid_size):
        # x: (bs, seq_len, heads, channels); assumes seq_len == h * w for
        # grid_size == (h, w) — TODO confirm with callers.
        bs, seq_len, heads = x.shape[:3]
        h, w = grid_size

        # View consecutive channel pairs as complex numbers.
        x_complex = torch.view_as_complex(
            x.float().reshape(bs, seq_len, heads, -1, 2)
        )
        # Row phases for the first half of pairs, column phases for the rest.
        freqs = torch.cat([
            self.freq_h[:h].view(1, h, 1, -1).expand(bs, h, w, -1),
            self.freq_w[:w].view(1, 1, w, -1).expand(bs, h, w, -1)
        ], dim=-1).reshape(bs, seq_len, 1, -1)

        # Complex multiply == 2-D rotation of each channel pair.
        x_out = x_complex * freqs
        x_out = torch.view_as_real(x_out).flatten(3)

        return x_out.type_as(x)
+
+
class MemoryEfficientAttention(nn.Module):
    """Multi-head attention with optional RoPE, q/k RMSNorm, and a masked
    "split cross-attention" mode that routes foreground/background queries to
    different halves of the reference context."""
    # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223

    def __init__(
        self,
        query_dim,
        context_dim = None,
        heads = None,
        dim_head = 64,
        dropout = 0.0,
        log = False,
        causal = False,
        rope = False,
        max_seq_len = 1024,
        qk_norm = False,
        **kwargs
    ):
        super().__init__()
        if log:
            print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
                  f"{heads} heads.")

        # Derive the head count from dim_head when not given explicitly.
        heads = heads or query_dim // dim_head
        inner_dim = dim_head * heads
        # Self-attention when no separate context dim is provided.
        context_dim = default(context_dim, query_dim)

        self.heads = heads
        self.dim_head = dim_head

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

        self.q_norm = RMSNorm(inner_dim) if qk_norm else Identity()
        self.k_norm = RMSNorm(inner_dim) if qk_norm else Identity()
        self.rope = RotaryPositionalEmbeddings(dim_head, max_seq_len=max_seq_len) if rope else Identity()
        # Extra kwargs forwarded to the attention kernel when causal.
        self.attn_ops = causal_ops if causal else {}

        # default setting for split cross-attention
        self.bg_scale = 1.
        self.fg_scale = 1.
        self.merge_scale = 0.
        self.mask_threshold = 0.05

    @checkpoint_wrapper
    def forward(
        self,
        x,
        context=None,
        mask=None,
        scale=1.,
        scale_factor=None,
        grid_size=None,
        **kwargs,
    ):
        # Self-attention when no context is given.
        context = default(context, x)

        if exists(mask):
            # Region-routed cross-attention (see masked_forward).
            out = self.masked_forward(x, context, mask, scale, scale_factor)
        else:
            q = self.to_q(x)
            k = self.to_k(context)
            v = self.to_v(context)
            out = self.attn_forward(q, k, v, scale, grid_size)

        return self.to_out(out)

    def attn_forward(self, q, k, v, scale=1., grid_size=None, mask=None):
        # Optional norm, split heads, and apply RoPE to q/k only.
        q, k = map(
            lambda t:
            self.rope(rearrange(t, "b n (h c) -> b n h c", h=self.heads), grid_size),
            (self.q_norm(q), self.k_norm(k))
        )
        v = rearrange(v, "b n (h c) -> b n h c", h=self.heads)
        out = attn_processor(q, k, v, attn_mask=mask, **self.attn_ops) * scale
        out = rearrange(out, "b n h c -> b n (h c)")
        return out

    def masked_forward(self, x, context, mask, scale=1., scale_factor=None):
        # split cross-attention function
        def qkv_forward(x, context):
            q = self.to_q(x)
            k = self.to_k(context)
            v = self.to_v(context)
            return q, k, v

        assert exists(scale_factor), "Scale factor must be assigned before masked attention"
        # Resize the spatial mask to the current feature resolution, then
        # flatten to one value per query token.
        mask = rearrange(
            F.interpolate(mask, scale_factor=scale_factor, mode="bicubic"),
            "b c h w -> b (h w) c"
        ).contiguous()

        if self.merge_scale > 0:
            # split cross-attention with merging scale, need two times forward
            c1, c2 = context.chunk(2, dim=1)

            # Background region cross-attention
            q2, k2, v2 = qkv_forward(x, c2)
            bg_out = self.attn_forward(q2, k2, v2, scale) * self.bg_scale

            # Foreground region cross-attention
            q1, k1, v1 = qkv_forward(x, c1)
            fg_out = self.attn_forward(q1, k1, v1, scale) * self.fg_scale

            # Blend the two, then keep the pure bg output where the mask
            # marks background.
            fg_out = fg_out * (1 - self.merge_scale) + bg_out * self.merge_scale
            return torch.where(mask < self.mask_threshold, bg_out, fg_out)

        else:
            # Single pass: an additive bias restricts each query to its half
            # of the context (see create_masked_attention_bias).
            attn_mask = create_masked_attention_bias(
                mask, self.mask_threshold, self.heads, context.size(1)
            )
            q, k, v = qkv_forward(x, context)
            return self.attn_forward(q, k, v, mask=attn_mask) * scale
+
+
class MultiModalAttention(MemoryEfficientAttention):
    """Cross-attention over two stacked context modalities, attended
    separately and summed with per-modality scales.

    Expects `context` with 4 dims (b, m, n, c), chunked in half along dim 1
    into the two modality streams. `scale` may be a scalar (applied to both
    streams) or a (scale_1, scale_2) pair.
    """

    def __init__(self, query_dim, context_dim_2, heads=8, dim_head=64, qk_norm=False, *args, **kwargs):
        super().__init__(query_dim, heads=heads, dim_head=dim_head, qk_norm=qk_norm, *args, **kwargs)
        inner_dim = dim_head * heads
        # Separate key/value projections (and key norm) for the second modality.
        self.to_k_2 = nn.Linear(context_dim_2, inner_dim, bias=False)
        self.to_v_2 = nn.Linear(context_dim_2, inner_dim, bias=False)
        self.k2_norm = RMSNorm(inner_dim) if qk_norm else Identity()

    def forward(self, x, context=None, mask=None, scale=1., grid_size=None):
        if not isinstance(scale, (list, tuple)):
            scale = (scale, scale)
        assert len(context.shape) == 4, "Multi-modal attention requires different context inputs to be (b, m, n c)"
        # NOTE(review): the chunks keep their singleton modality dim (4-D)
        # going into the linear projections — confirm callers' layout.
        context, context2 = context.chunk(2, dim=1)

        q = self.to_q(x)
        k = self.to_k(context)
        v = self.to_v(context)
        k2 = self.to_k_2(context2)
        # BUG FIX: values for the second modality were previously projected
        # with to_k_2, leaving to_v_2 unused and feeding key-space features
        # in as values.
        v2 = self.to_v_2(context2)

        # Normalize keys/queries (optional), split heads, apply RoPE to q/k.
        q, k, k2 = map(
            lambda t: self.rope(rearrange(t, "b n (h c) -> b n h c", h=self.heads), grid_size),
            (self.q_norm(q), self.k_norm(k), self.k2_norm(k2))
        )
        v, v2 = map(lambda t: rearrange(t, "b n (h c) -> b n h c", h=self.heads), (v, v2))

        # Attend to each modality separately and sum with per-stream scales.
        out = (attn_processor(q, k, v, **self.attn_ops) * scale[0] +
               attn_processor(q, k2, v2, **self.attn_ops) * scale[1])

        if exists(mask):
            raise NotImplementedError
        out = rearrange(out, "b n h c -> b n (h c)")
        return self.to_out(out)
+
+
class MultiScaleCausalAttention(MemoryEfficientAttention):
    """Block-causal attention over a concatenation of per-scale token groups:
    queries of each group attend to that group and all earlier groups."""

    def forward(
        self,
        x,
        context=None,
        mask=None,
        scale=1.,
        scale_factor=None,
        grid_size=None,
        token_lens=None
    ):
        context = default(context, x)
        q = self.to_q(x)
        k = self.to_k(context)
        v = self.to_v(context)
        out = self.attn_forward(q, k, v, scale, grid_size=grid_size, token_lens=token_lens)
        return self.to_out(out)

    def attn_forward(self, q, k, v, scale = 1., grid_size = None, token_lens = None):
        q, k, v = map(
            lambda t: rearrange(t, "b n (h c) -> b n h c", h=self.heads),
            (self.q_norm(q), self.k_norm(k), v)
        )

        attn_output = []
        prev_idx = 0
        # grid_size[i] / token_lens[i] give the i-th group's (h, w) grid and
        # token count. The `(idx == 0)` offsets account for one extra leading
        # token in the first group that is excluded from the RoPE slice —
        # NOTE(review): presumably a global/class token; confirm with callers.
        for idx, (grid, length) in enumerate(zip(grid_size, token_lens)):
            end_idx = prev_idx + length + (idx == 0)
            rope_prev_idx = prev_idx + (idx == 0)
            rope_slice = slice(rope_prev_idx, end_idx)

            # Rotate q/k positions within this group's grid.
            q[:, rope_slice] = self.rope(q[:, rope_slice], grid)
            k[:, rope_slice] = self.rope(k[:, rope_slice], grid)
            # This group's queries attend to all tokens up to and including
            # this group (block-causal over groups).
            qs = q[:, prev_idx: end_idx]
            ks, vs = map(lambda t: t[:, :end_idx], (k, v))

            attn_output.append(attn_processor(qs.clone(), ks.clone(), vs.clone()) * scale)
            prev_idx = end_idx
        attn_output = rearrange(torch.cat(attn_output, 1), "b n h c -> b n (h c)")
        return attn_output

        # if FLASH_ATTN_3_AVAILABLE or FLASH_ATTN_AVAILABLE:
        #     k_chunks = []
        #     v_chunks = []
        #     kv_token_lens = []
        #     prev_idx = 0
        #     for idx, (grid, length) in enumerate(zip(grid_size, token_lens)):
        #         end_idx = prev_idx + length + (idx == 0)
        #         rope_prev_idx = prev_idx + (idx == 0)

        #         rope_slice = slice(rope_prev_idx, end_idx)
        #         q[:, rope_slice], k[:, rope_slice], v[:, rope_slice] = map(
        #             lambda t: self.rope(t[:, rope_slice], grid),
        #             (q, k, v)
        #         )
        #         kv_token_lens.append(end_idx+1)
        #         k_chunks.append(k[:, :end_idx])
        #         v_chunks.append(v[:, :end_idx])
        #         prev_idx = end_idx
        #     k = torch.cat(k_chunks, 1)
        #     v = torch.cat(v_chunks, 1)
        #     B, N, H, C = q.shape
        #     token_lens = torch.tensor(token_lens, device=q.device, dtype=torch.int32)
        #     kv_token_lens = torch.tensor(kv_token_lens, device=q.device, dtype=torch.int32)
        #     token_lens[0] = token_lens[0] + 1
        #
        #     cu_seqlens_q, cu_seqlens_kv = map(lambda t:
        #         torch.cat([t.new_zeros([1]), t]).cumsum(0, dtype=torch.int32),
        #         (token_lens, kv_token_lens)
        #     )
        #     max_seqlen_q, max_seqlen_kv = map(lambda t: int(t.max()), (token_lens, kv_token_lens))
        #
        #     q_flat = q.reshape(-1, H, C).contiguous()
        #     k_flat = k.reshape(-1, H, C).contiguous()
        #     v_flat = v.reshape(-1, H, C).contiguous()
        #     out_flat = flash_attn_varlen_func(
        #         q=q_flat, k=k_flat, v=v_flat,
        #         cu_seqlens_q=cu_seqlens_q,
        #         cu_seqlens_k=cu_seqlens_kv,
        #         max_seqlen_q=max_seqlen_q,
        #         max_seqlen_k=max_seqlen_kv,
        #         causal=True,
        #     )
        #
        #     out = rearrange(out_flat, "(b n) h c -> b n (h c)", b=B, n=N)
        #     return out * scale
diff --git a/refnet/modules/attn_utils.py b/refnet/modules/attn_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c54bef88bad3e0eb0871c04f2c024568a2c60f6
--- /dev/null
+++ b/refnet/modules/attn_utils.py
@@ -0,0 +1,155 @@
+import torch
+import torch.nn.functional as F
+
+ATTN_PRECISION = torch.float16
+
+try:
+ import flash_attn_interface
+ FLASH_ATTN_3_AVAILABLE = True
+ FLASH_ATTN_AVAILABLE = False
+
+except ModuleNotFoundError:
+ FLASH_ATTN_3_AVAILABLE = False
+ try:
+ import flash_attn
+ FLASH_ATTN_AVAILABLE = True
+ except ModuleNotFoundError:
+ FLASH_ATTN_AVAILABLE = False
+
+try:
+ import xformers.ops
+ XFORMERS_IS_AVAILBLE = True
+except:
+ XFORMERS_IS_AVAILBLE = False
+
+
+def half(x):
+ if x.dtype not in [torch.float16, torch.bfloat16]:
+ x = x.to(ATTN_PRECISION)
+ return x
+
+def attn_processor(q, k, v, attn_mask = None, *args, **kwargs):
+    """Run attention with the best available backend.
+
+    q/k/v are expected in (batch, seq, heads, head_dim) layout — the layout
+    flash-attn and xformers consume directly; they are transposed to
+    (batch, heads, seq, head_dim) only around the torch SDPA fallback.
+
+    Backend selection:
+      * with a mask: xformers (mask as attn_bias) if installed, otherwise
+        torch SDPA — the flash paths are not used on this branch;
+      * without a mask: FA-3 > FA-2 > xformers > torch SDPA.
+    The flash paths cast q/k/v to fp16/bf16 and restore v's dtype on output.
+    """
+    if attn_mask is not None:
+        if XFORMERS_IS_AVAILBLE:
+            out = xformers.ops.memory_efficient_attention(
+                q, k, v, attn_bias=attn_mask, *args, **kwargs
+            )
+        else:
+            # SDPA wants (b, h, n, c); transpose in and back out.
+            q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))
+            out = F.scaled_dot_product_attention(
+                q, k, v, attn_mask=attn_mask, *args, **kwargs
+            ).transpose(1, 2)
+    else:
+        if FLASH_ATTN_3_AVAILABLE:
+            dtype = v.dtype
+            q, k, v = map(lambda t: half(t), (q, k, v))
+            # FA-3 returns (out, lse); keep only the attention output.
+            out = flash_attn_interface.flash_attn_func(q, k, v, *args, **kwargs)[0].to(dtype)
+        elif FLASH_ATTN_AVAILABLE:
+            dtype = v.dtype
+            q, k, v = map(lambda t: half(t), (q, k, v))
+            out = flash_attn.flash_attn_func(q, k, v, *args, **kwargs).to(dtype)
+        elif XFORMERS_IS_AVAILBLE:
+            out = xformers.ops.memory_efficient_attention(q, k, v, *args, **kwargs)
+        else:
+            q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))
+            out = F.scaled_dot_product_attention(q, k, v, *args, **kwargs).transpose(1, 2)
+    return out
+
+
+def flash_attn_varlen_func(q, k, v, **kwargs):
+ if FLASH_ATTN_3_AVAILABLE:
+ return flash_attn_interface.flash_attn_varlen_func(q, k, v, **kwargs)[0]
+ else:
+ return flash_attn.flash_attn_varlen_func(q, k, v, **kwargs)
+
+
+def split_tensor_by_mask(tensor: torch.Tensor, mask: torch.Tensor, threshold: float = 0.5):
+ """
+ Split input tensor into foreground and background based on mask, then concatenate them.
+
+ Args:
+ tensor: Input tensor of shape (batch, seq_len, dim)
+ mask: Binary mask of shape (batch, seq_len, 1) or (batch, seq_len)
+ threshold: Threshold for mask binarization
+
+ Returns:
+ split_tensor: Concatenated tensor with foreground first, then background
+ fg_indices: Indices of foreground elements for restoration
+ bg_indices: Indices of background elements for restoration
+ original_shape: Original tensor shape for restoration
+ """
+ batch_size, seq_len, *dims = tensor.shape
+ device, dtype = tensor.device, tensor.dtype
+
+ # Ensure mask has correct shape and binarize
+ if mask.dim() == 2:
+ mask = mask.unsqueeze(-1)
+ binary_mask = (mask > threshold).squeeze(-1) # Shape: (batch, seq_len)
+
+ # Store indices for restoration (keep minimal loop for complex indexing)
+ fg_indices = [torch.where(binary_mask[b])[0] for b in range(batch_size)]
+ bg_indices = [torch.where(~binary_mask[b])[0] for b in range(batch_size)]
+
+ # Count elements efficiently
+ fg_counts = binary_mask.sum(dim=1)
+ bg_counts = (~binary_mask).sum(dim=1)
+ max_fg_len = fg_counts.max().item()
+ max_bg_len = bg_counts.max().item()
+
+ # Early exit if no elements
+ if max_fg_len == 0 and max_bg_len == 0:
+ return torch.zeros(batch_size, 0, *dims, device=device, dtype=dtype), fg_indices, bg_indices, tensor.shape
+
+ # Create output tensor
+ split_tensor = torch.zeros(batch_size, max_fg_len + max_bg_len, *dims, device=device, dtype=dtype)
+
+ # Vectorized approach using gather for better efficiency
+ for b in range(batch_size):
+ if len(fg_indices[b]) > 0:
+ split_tensor[b, :len(fg_indices[b])] = tensor[b][fg_indices[b]]
+ if len(bg_indices[b]) > 0:
+ split_tensor[b, max_fg_len:max_fg_len + len(bg_indices[b])] = tensor[b][bg_indices[b]]
+
+ return split_tensor, fg_indices, bg_indices, tensor.shape
+
+
+def restore_tensor_from_split(split_tensor: torch.Tensor, fg_indices: list, bg_indices: list,
+ original_shape: torch.Size):
+ """
+ Restore original tensor from split tensor using stored indices.
+
+ Args:
+ split_tensor: Split tensor from split_tensor_by_mask
+ fg_indices: List of foreground indices for each batch
+ bg_indices: List of background indices for each batch
+ original_shape: Original tensor shape
+
+ Returns:
+ restored_tensor: Restored tensor with original shape and ordering
+ """
+ batch_size, seq_len = original_shape[:2]
+ dims = original_shape[2:]
+ device, dtype = split_tensor.device, split_tensor.dtype
+
+ # Calculate split point efficiently
+ max_fg_len = max((len(fg) for fg in fg_indices), default=0)
+
+ # Initialize restored tensor
+ restored_tensor = torch.zeros(batch_size, seq_len, *dims, device=device, dtype=dtype)
+
+ # Early exit if no elements to restore
+ if split_tensor.shape[1] == 0:
+ return restored_tensor
+
+ # Split tensor parts
+ fg_part = split_tensor[:, :max_fg_len] if max_fg_len > 0 else None
+ bg_part = split_tensor[:, max_fg_len:] if split_tensor.shape[1] > max_fg_len else None
+
+ # Restore in single loop with efficient indexing
+ for b in range(batch_size):
+ if fg_part is not None and len(fg_indices[b]) > 0:
+ restored_tensor[b, fg_indices[b]] = fg_part[b, :len(fg_indices[b])]
+ if bg_part is not None and len(bg_indices[b]) > 0:
+ restored_tensor[b, bg_indices[b]] = bg_part[b, :len(bg_indices[b])]
+
+ return restored_tensor
diff --git a/refnet/modules/embedder.py b/refnet/modules/embedder.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c564f428ba7b8698b0c83567e10348f93ff9771
--- /dev/null
+++ b/refnet/modules/embedder.py
@@ -0,0 +1,489 @@
+import os
+import math
+import numpy as np
+
+from tqdm import tqdm
+from einops import rearrange
+from refnet.util import exists, append_dims
+from refnet.sampling import tps_warp
+from refnet.ldm.openaimodel import Timestep, zero_module
+
+import timm
+import torch
+import torch.nn as nn
+import torchvision.transforms
+import torch.nn.functional as F
+
+from huggingface_hub import hf_hub_download
+from torch.utils.checkpoint import checkpoint
+from safetensors.torch import load_file
+from transformers import (
+ T5EncoderModel,
+ T5Tokenizer,
+ CLIPVisionModelWithProjection,
+ CLIPTextModel,
+ CLIPTokenizer,
+)
+
+# open_clip architecture name -> pretrained tag (None = weights resolved
+# directly from the hf-hub id embedded in the key).
+versions = {
+    "ViT-bigG-14": "laion2b_s39b_b160k",
+    "ViT-H-14": "laion2b_s32b_b79k", # resblocks layers: 32
+    "ViT-L-14": "laion2b_s32b_b82k",
+    "hf-hub:apple/DFN5B-CLIP-ViT-H-14-384": None, # arch name [DFN-ViT-H]
+}
+# The same architectures resolved to Hugging Face hub repository ids.
+hf_versions = {
+    "ViT-bigG-14": "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
+    "ViT-H-14": "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
+    "ViT-L-14": "openai/clip-vit-large-patch14",
+}
+# Local weight cache; honours HF_HOME when set.
+cache_dir = os.environ.get("HF_HOME", "./pretrained_models")
+
+
+class WDv14SwinTransformerV2(nn.Module):
+    """
+    WD-v14-tagger
+    Author: Smiling Wolf
+    Link: https://huggingface.co/SmilingWolf/wd-v1-4-swinv2-tagger-v2
+
+    Frozen SwinV2 tagger used either as a feature extractor (intermediate
+    transformer features) or, with ``load_tag=True``, as a multi-label
+    anime-tag classifier.
+    """
+    # Fill value for logit channels suppressed by `logit_threshold`
+    # (sigmoid(-22) is effectively zero probability).
+    negative_logit = -22
+
+    def __init__(
+        self,
+        input_size = 448,
+        antialias = True,
+        layer_idx = 0.,
+        load_tag = False,
+        logit_threshold = None,
+        direct_forward = False,
+    ):
+        """
+
+        Args:
+            input_size: Input image size
+            antialias: Antialias during rescaling
+            layer_idx: Extracted feature layer, counted back from the last
+                transformer block (0 = run every block)
+            load_tag: Set it to true if use the embedder for image classification
+            logit_threshold: Filtering specific channels in logits output
+            direct_forward: Rebind ``model.forward`` to ``forward_features``
+                so calls skip the classification head
+        """
+        from refnet.modules import wd_v14_swin2_tagger_config
+        super().__init__()
+        custom_config = wd_v14_swin2_tagger_config()
+        self.model: nn.Module = timm.create_model(
+            custom_config.architecture,
+            pretrained = False,
+            num_classes = custom_config.num_classes,
+            global_pool = custom_config.global_pool,
+            **custom_config.model_args
+        )
+        self.image_size = input_size
+        self.antialias = antialias
+        self.layer_idx = layer_idx
+        self.load_tag = load_tag
+        self.logit_threshold = logit_threshold
+        self.direct_forward = direct_forward
+
+        self.load_from_pretrained_url(load_tag)
+        self.get_transformer_length()
+        # Inference-only: frozen and permanently in eval mode.
+        self.model.eval()
+        self.model.requires_grad_(False)
+
+        if self.direct_forward:
+            # Bind forward_features as the model's forward so later calls
+            # bypass pooling and the classification head entirely.
+            self.model.forward = self.model.forward_features.__get__(self.model, self.model.__class__)
+
+
+    def load_from_pretrained_url(self, load_tag=False):
+        """Download tagger weights (and the tag CSV when *load_tag*) into cache_dir."""
+        import pandas as pd
+        from torch.hub import download_url_to_file
+        from data.tag_utils import load_labels, color_tag_index, geometry_tag_index
+
+        ckpt_path = os.path.join(cache_dir, "wd-v14-swin2-tagger.safetensors")
+        if not os.path.exists(ckpt_path):
+            # Download to a temp name, then rename so a partial download
+            # never masquerades as a complete checkpoint.
+            cache_path = os.path.join(cache_dir, "weights.tmp")
+            download_url_to_file(
+                "https://huggingface.co/SmilingWolf/wd-v1-4-swinv2-tagger-v2/resolve/main/model.safetensors",
+                dst = cache_path
+            )
+            os.rename(cache_path, ckpt_path)
+
+        if load_tag:
+            csv_path = hf_hub_download(
+                "SmilingWolf/wd-v1-4-swinv2-tagger-v2",
+                "selected_tags.csv",
+                cache_dir = cache_dir
+                # use_auth_token=HF_TOKEN,
+            )
+            tags_df = pd.read_csv(csv_path)
+            sep_tags = load_labels(tags_df)
+
+            # Index groups into the flat tag-name list.
+            self.tag_names = sep_tags[0]
+            self.rating_indexes = sep_tags[1]
+            self.general_indexes = sep_tags[2]
+            self.character_indexes = sep_tags[3]
+
+            self.color_tags = color_tag_index
+            self.expr_tags = geometry_tag_index
+        self.model.load_state_dict(load_file(ckpt_path))
+
+
+    def convert_labels(self, pred, general_thresh=0.25, character_thresh=0.85):
+        """Convert sigmoid predictions to a prompt string plus scored tags.
+
+        Returns:
+            Tuple of (general+character tag string, general (tag, score)
+            pairs sorted by descending confidence).
+        """
+        assert self.load_tag
+        labels = list(zip(self.tag_names, pred[0].astype(float)))
+
+        # First 4 labels are actually ratings: pick one with argmax
+        # ratings_names = [labels[i] for i in self.rating_indexes]
+        # rating = dict(ratings_names)
+
+        # Then we have general tags: pick any where prediction confidence > threshold
+        general_names = [labels[i] for i in self.general_indexes]
+
+        general_res = [(x[0], np.round(x[1], decimals=4)) for x in general_names if x[1] > general_thresh]
+        general_res = dict(general_res)
+
+        # Everything else is characters: pick any where prediction confidence > threshold
+        character_names = [labels[i] for i in self.character_indexes]
+
+        character_res = [x for x in character_names if x[1] > character_thresh]
+        character_res = dict(character_res)
+
+        # NOTE(review): the two sorts below are identical; the first
+        # `sorted_general_strings` is immediately rebuilt from itself.
+        sorted_general_strings = sorted(
+            general_res.items(),
+            key=lambda x: x[1],
+            reverse=True,
+        )
+
+        sorted_general_res = sorted(
+            general_res.items(),
+            key=lambda x: x[1],
+            reverse=True,
+        )
+        sorted_general_strings = [x[0] for x in sorted_general_strings]
+        # Escape parentheses for downstream prompt parsers.
+        sorted_general_strings = ", ".join(sorted_general_strings).replace("(", "\\(").replace(")", "\\)")
+
+        # return sorted_general_strings, rating, character_res, general_res
+        # NOTE(review): there is no ", " separator between the general block
+        # and the first character tag when both are non-empty — confirm intended.
+        return sorted_general_strings + ", ".join([x[0] for x in character_res.items()]), sorted_general_res
+
+    def get_transformer_length(self):
+        """Count transformer blocks across all Swin stages (used by layer_idx)."""
+        length = 0
+        for stage in self.model.layers:
+            length += len(stage.blocks)
+        self.transformer_length = length
+
+    def transformer_forward(self, x):
+        """Run Swin blocks, optionally early-exiting `layer_idx` blocks before the end."""
+        idx = 0
+        x = self.model.patch_embed(x)
+        for stage in self.model.layers:
+            x = stage.downsample(x)
+            for blk in stage.blocks:
+                # Return the *input* of block (length - layer_idx), i.e. the
+                # output of the preceding block; with layer_idx == 0 this
+                # condition never fires and all blocks run.
+                if idx == self.transformer_length - self.layer_idx:
+                    return x
+                if not torch.jit.is_scripting():
+                    # Gradient checkpointing keeps activation memory bounded.
+                    x = checkpoint(blk, x, use_reentrant=False)
+                else:
+                    x = blk(x)
+                idx += 1
+        return x
+
+
+    def forward(self, x, return_logits=False, pooled=True, **kwargs):
+        """Return [features, logits]; logits is None unless *return_logits*."""
+        # x: [b, h, w, 3]
+        if self.direct_forward:
+            x = self.model(x)
+        else:
+            x = self.transformer_forward(x)
+            x = self.model.norm(x)
+
+        # x: [b, 14, 14, 1024]
+        if return_logits:
+            if pooled:
+                logits = self.model.forward_head(x).unsqueeze(1)
+                # x: [b, 1, 1024]
+
+            else:
+                # Per-token logits through the classification head, no pooling.
+                logits = self.model.head.fc(x)
+                # x = F.sigmoid(x)
+                logits = rearrange(logits, "b h w c -> b (h w) c").contiguous()
+                # x: [b, 196, 9083]
+
+                # Need a threshold to cut off unnecessary classes.
+                if exists(self.logit_threshold) and isinstance(self.logit_threshold, float):
+                    logits = torch.where(
+                        logits > self.logit_threshold,
+                        logits,
+                        torch.ones_like(logits) * self.negative_logit
+                    )
+
+        else:
+            logits = None
+
+        if pooled:
+            # Global-average the spatial token grid down to one token.
+            x = x.mean(dim=[1, 2]).unsqueeze(1)
+        else:
+            x = rearrange(x, "b h w c -> b (h w) c").contiguous()
+        return [x, logits]
+
+    def preprocess(self, x: torch.Tensor):
+        """Resize to model resolution and swap RGB -> BGR channel order."""
+        x = F.interpolate(
+            x,
+            (self.image_size, self.image_size),
+            mode = "bicubic",
+            align_corners = True,
+            antialias = self.antialias
+        )
+        # convert RGB to BGR
+        x = x[:, [2, 1, 0]]
+        return x
+
+    @torch.no_grad()
+    def encode(self, img: torch.Tensor, return_logits=False, pooled=True, **kwargs):
+        # Input image must be in RGB format
+        return self(self.preprocess(img), return_logits, pooled)
+
+    @torch.no_grad()
+    def predict_labels(self, img: torch.Tensor, *args, **kwargs):
+        """Classify one image (batch size 1) and return tag strings/scores."""
+        assert len(img.shape) == 4 and img.shape[0] == 1
+        logits = self(self.preprocess(img), return_logits=True, pooled=True)[1]
+        logits = F.sigmoid(logits).detach().cpu().numpy()
+        return self.convert_labels(logits, *args, **kwargs)
+
+    def geometry_update(self, emb, geometry_emb, scale_factor=1):
+        """Overwrite the geometry channels of *emb* with those of *geometry_emb*.
+
+        Args:
+            emb: WD embedding from reference image
+            geometry_emb: WD embedding from sketch image
+            scale_factor: Gain applied to the injected geometry channels
+        """
+        geometry_mask = torch.zeros_like(emb)
+        geometry_mask[:, :, self.expr_tags] = 1 # Only geometry channels
+        emb = emb * (1 - geometry_mask) + geometry_emb * geometry_mask * scale_factor
+        return emb
+
+    @property
+    def dtype(self):
+        # Parameter dtype of the classification head, taken as the model dtype.
+        return self.model.head.fc.weight.dtype
+
+
+class OpenCLIP(nn.Module):
+ def __init__(self, vision_config=None, text_config=None, **kwargs):
+ super().__init__()
+ if exists(vision_config):
+ vision_config.update(kwargs)
+ else:
+ vision_config = kwargs
+
+ if exists(text_config):
+ text_config.update(kwargs)
+ else:
+ text_config = kwargs
+
+ self.visual = FrozenOpenCLIPImageEmbedder(**vision_config)
+ self.transformer = FrozenOpenCLIPEmbedder(**text_config)
+
+ def preprocess(self, x):
+ return self.visual.preprocess(x)
+
+ @property
+ def scale_factor(self):
+ return self.visual.scale_factor
+
+ def update_scale_factor(self, scale_factor):
+ self.visual.update_scale_factor(scale_factor)
+
+ def encode(self, *args, **kwargs):
+ return self.visual.encode(*args, **kwargs)
+
+ @torch.no_grad()
+ def encode_text(self, text, normalize=True):
+ return self.transformer(text, normalize)
+
+ def calculate_scale(self, v: torch.Tensor, t: torch.Tensor):
+ """
+ Calculate the projection of v along the direction of t
+ params:
+ v: visual tokens from clip image encoder, shape: (b, n, c)
+ t: text features from clip text encoder (argmax -1), shape: (b, 1, c)
+ """
+ return v @ t.mT
+
+
+
+class HFCLIPVisionModel(nn.Module):
+ # TODO: open_clip_torch is incompatible with deepspeed ZeRO3, change to huggingface implementation in the future
+ def __init__(self, arch="ViT-bigG-14", image_size=224, scale_factor=1.):
+ super().__init__()
+ self.model = CLIPVisionModelWithProjection.from_pretrained(
+ hf_versions[arch],
+ cache_dir = cache_dir
+ )
+ self.image_size = image_size
+ self.scale_factor = scale_factor
+ self.register_buffer(
+ 'mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]).view(1, -1, 1, 1), persistent=False
+ )
+ self.register_buffer(
+ 'std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]).view(1, -1, 1, 1), persistent=False
+ )
+ self.antialias = True
+ self.requires_grad_(False).eval()
+
+ def preprocess(self, x):
+ # normalize to [0,1]
+ ns = int(self.image_size * self.scale_factor)
+ x = F.interpolate(x, (ns, ns), mode="bicubic", align_corners=True, antialias=self.antialias)
+ x = (x + 1.0) / 2.0
+
+ # renormalize according to clip
+ x = (x - self.mean) / self.std
+ return x
+
+ def forward(self, x, output_type):
+ outputs = self.model(x).last_hidden_state
+ if output_type == "cls":
+ outputs = outputs[:, :1]
+ elif output_type == "local":
+ outputs = outputs[:, 1:]
+ outputs = self.model.vision_model.post_layernorm(outputs)
+ outputs = self.model.visual_projection(outputs)
+ return outputs
+
+ @torch.no_grad()
+ def encode(self, img, output_type="full", preprocess=True, warp_p=0., **kwargs):
+ img = self.preprocess(img) if preprocess else img
+
+ if warp_p > 0.:
+ rand = append_dims(torch.rand(img.shape[0], device=img.device, dtype=img.dtype), img.ndim)
+ img = torch.where(torch.Tensor(rand > warp_p), img, tps_warp(img))
+ return self(img, output_type)
+
+
+
+
+class FrozenT5Embedder(nn.Module):
+ """Uses the T5 transformer encoder for text"""
+
+ def __init__(
+ self, version="google/t5-v1_1-xxl", device="cuda", max_length=77, freeze=True
+ ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
+ super().__init__()
+ self.tokenizer = T5Tokenizer.from_pretrained(version, cache_dir=cache_dir)
+ self.transformer = T5EncoderModel.from_pretrained(version, cache_dir=cache_dir)
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(
+ text,
+ truncation=True,
+ max_length=self.max_length,
+ return_length=True,
+ return_overflowing_tokens=False,
+ padding="max_length",
+ return_tensors="pt",
+ )
+ tokens = batch_encoding["input_ids"].to(self.device)
+ with torch.autocast("cuda", enabled=False):
+ outputs = self.transformer(input_ids=tokens)
+ z = outputs.last_hidden_state
+ return z
+
+ @torch.no_grad()
+ def encode(self, text):
+ return self(text)
+
+
+class HFCLIPTextEmbedder(nn.Module):
+ def __init__(self, arch, freeze=True, device="cuda", max_length=77):
+ super().__init__()
+ self.tokenizer = CLIPTokenizer.from_pretrained(
+ hf_versions[arch],
+ cache_dir = cache_dir
+ )
+ self.model = CLIPTextModel.from_pretrained(
+ hf_versions[arch],
+ cache_dir = cache_dir
+ )
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+
+ def freeze(self):
+ self.model = self.model.eval()
+
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ if isinstance(text, torch.Tensor) and text.dtype == torch.long:
+ # Input is already tokenized
+ tokens = text
+ else:
+ # Need to tokenize text input
+ batch_encoding = self.tokenizer(
+ text,
+ truncation=True,
+ max_length=self.max_length,
+ padding="max_length",
+ return_tensors="pt",
+ )
+ tokens = batch_encoding["input_ids"].to(self.device)
+
+ outputs = self.model(input_ids=tokens)
+ z = outputs.last_hidden_state
+ return z
+
+ @torch.no_grad()
+ def encode(self, text, normalize=False):
+ outputs = self(text)
+ if normalize:
+ outputs = outputs / outputs.norm(dim=-1, keepdim=True)
+ return outputs
+
+
+class ScalarEmbedder(nn.Module):
+ """embeds each dimension independently and concatenates them"""
+
+ def __init__(self, embed_dim, out_dim):
+ super().__init__()
+ self.timestep = Timestep(embed_dim)
+ self.embed_layer = nn.Sequential(
+ nn.Linear(embed_dim, out_dim),
+ nn.SiLU(),
+ zero_module(nn.Linear(out_dim, out_features=out_dim))
+ )
+
+ def forward(self, x, dtype=torch.float32):
+ emb = self.timestep(x)
+ emb = rearrange(emb, "b d -> b 1 d")
+ emb = self.embed_layer(emb.to(dtype))
+ return emb
+
+
+class TimestepEmbedding(nn.Module):
+ def __init__(self, embed_dim):
+ super().__init__()
+ self.timestep = Timestep(embed_dim)
+
+ def forward(self, x):
+ x = self.timestep(x)
+ return x
+
+
+if __name__ == '__main__':
+    # Smoke test: embed one image and print the token-feature shape.
+    import PIL.Image as Image
+
+    encoder = FrozenOpenCLIPImageEmbedder(arch="DFN-ViT-H")
+    image = Image.open("../../miniset/origin/70717450.jpg").convert("RGB")
+    # ToTensor yields [0, 1]; shift to [-1, 1].
+    image = (torchvision.transforms.ToTensor()(image) - 0.5) * 2
+    image = image.unsqueeze(0)
+    print(image.shape)
+    feat = encoder.encode(image, "local")
+    print(feat.shape)
\ No newline at end of file
diff --git a/refnet/modules/encoder.py b/refnet/modules/encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..e111c95b8a64034be2c919ea3c6b8fb6cb40ed86
--- /dev/null
+++ b/refnet/modules/encoder.py
@@ -0,0 +1,224 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from refnet.util import checkpoint_wrapper
+from refnet.modules.unet import TimestepEmbedSequential
+from refnet.modules.layers import Upsample, zero_module, RMSNorm, FeedForward
+from refnet.modules.attention import MemoryEfficientAttention, MultiScaleCausalAttention
+from einops import rearrange
+from functools import partial
+
+
+
+def make_zero_conv(in_channels, out_channels=None):
+ out_channels = out_channels or in_channels
+ return zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0))
+
+def activate_zero_conv(in_channels, out_channels=None):
+ out_channels = out_channels or in_channels
+ return TimestepEmbedSequential(
+ nn.SiLU(),
+ zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0))
+ )
+
+def sequential_downsample(in_channels, out_channels, sequential_cls=nn.Sequential):
+ return sequential_cls(
+ nn.Conv2d(in_channels, 16, 3, padding=1),
+ nn.SiLU(),
+ nn.Conv2d(16, 16, 3, padding=1),
+ nn.SiLU(),
+ nn.Conv2d(16, 32, 3, padding=1, stride=2),
+ nn.SiLU(),
+ nn.Conv2d(32, 32, 3, padding=1),
+ nn.SiLU(),
+ nn.Conv2d(32, 96, 3, padding=1, stride=2),
+ nn.SiLU(),
+ nn.Conv2d(96, 96, 3, padding=1),
+ nn.SiLU(),
+ nn.Conv2d(96, 256, 3, padding=1, stride=2),
+ nn.SiLU(),
+ zero_module(nn.Conv2d(256, out_channels, 3, padding=1))
+ )
+
+
+class SimpleEncoder(nn.Module):
+ def __init__(self, c_channels, model_channels):
+ super().__init__()
+ self.model = sequential_downsample(c_channels, model_channels)
+
+ def forward(self, x, *args, **kwargs):
+ return self.model(x)
+
+
+class MultiEncoder(nn.Module):
+    """Conv stem plus a ladder of downsampling stages, each tapped through a
+    zero-initialized 1x1 conv to emit multi-resolution control hints."""
+
+    def __init__(self, in_ch, model_channels, ch_mults, checkpoint=True, time_embed=False):
+        super().__init__()
+        # TimestepEmbedSequential lets blocks accept a timestep embedding
+        # when used inside a timestep-conditioned pipeline.
+        sequential_cls = TimestepEmbedSequential if time_embed else nn.Sequential
+        output_chs = [model_channels * mult for mult in ch_mults]
+        # NOTE(review): the stem outputs `model_channels` channels but
+        # zero_layer expects output_chs[0]; this matches only when
+        # ch_mults[0] == 1 — confirm against callers.
+        self.model = sequential_downsample(in_ch, model_channels, sequential_cls)
+        self.zero_layer = make_zero_conv(output_chs[0])
+        self.output_blocks = nn.ModuleList()
+        self.zero_blocks = nn.ModuleList()
+
+        block_num = len(ch_mults)
+        prev_ch = output_chs[0]
+        for i in range(block_num):
+            # Last stage keeps resolution (stride 1); earlier stages halve it.
+            self.output_blocks.append(sequential_cls(
+                nn.SiLU(),
+                nn.Conv2d(prev_ch, output_chs[i], 3, padding=1, stride=2 if i != block_num-1 else 1),
+                nn.SiLU(),
+                nn.Conv2d(output_chs[i], output_chs[i], 3, padding=1)
+            ))
+            self.zero_blocks.append(
+                TimestepEmbedSequential(make_zero_conv(output_chs[i])) if time_embed
+                else make_zero_conv(output_chs[i])
+            )
+            prev_ch = output_chs[i]
+
+        # NOTE(review): `checkpoint` is stored but not used in forward().
+        self.checkpoint = checkpoint
+
+    def forward(self, x):
+        """Return zero-conv-tapped hints at every scale, finest first."""
+        x = self.model(x)
+        hints = [self.zero_layer(x)]
+        for layer, zero_layer in zip(self.output_blocks, self.zero_blocks):
+            x = layer(x)
+            hints.append(zero_layer(x))
+        return hints
+
+
+class MultiScaleAttentionEncoder(nn.Module):
+    """Downsampling pyramid whose per-scale tokens jointly attend through one
+    multi-scale causal attention layer, then project back to per-scale
+    control feature maps."""
+
+    def __init__(
+        self,
+        in_ch,
+        model_channels,
+        ch_mults,
+        dim_head = 128,
+        transformer_layers = 2,
+        checkpoint = True
+    ):
+        super().__init__()
+        # 1x1 projections into/out of the shared attention width.
+        conv_proj = partial(nn.Conv2d, kernel_size=1, padding=0)
+        output_chs = [model_channels * mult for mult in ch_mults]
+        block_num = len(ch_mults)
+        attn_ch = output_chs[-1]  # widest scale defines the attention width
+
+        self.model = sequential_downsample(in_ch, output_chs[0])
+        self.proj_ins = nn.ModuleList([conv_proj(output_chs[0], attn_ch)])
+        self.proj_outs = nn.ModuleList([zero_module(conv_proj(attn_ch, output_chs[0]))])
+
+        prev_ch = output_chs[0]
+        self.downsample_layers = nn.ModuleList()
+        for i in range(block_num):
+            ch = output_chs[i]
+            # Last stage keeps resolution (stride 1); earlier stages halve it.
+            self.downsample_layers.append(nn.Sequential(
+                nn.SiLU(),
+                nn.Conv2d(prev_ch, ch, 3, padding=1, stride=2 if i != block_num - 1 else 1),
+            ))
+            self.proj_ins.append(conv_proj(ch, attn_ch))
+            self.proj_outs.append(zero_module(conv_proj(attn_ch, ch)))
+            prev_ch = ch
+
+        # Extra projection for the global (spatially pooled) token.
+        self.proj_ins.append(conv_proj(attn_ch, attn_ch))
+        self.attn_layer = MultiScaleCausalAttention(attn_ch, rope=True, qk_norm=True, dim_head=dim_head)
+        # NOTE(review): `transformer_layers` is currently unused — the
+        # stacked-transformer variant below is commented out.
+        # self.transformer = nn.ModuleList([
+        #     BasicTransformerBlock(
+        #         attn_ch,
+        #         rotary_positional_embedding = True,
+        #         qk_norm = True,
+        #         d_head = dim_head,
+        #         disable_cross_attn = True,
+        #         self_attn_type = "multi-scale",
+        #         ff_mult = 2,
+        #     )
+        # ] * transformer_layers)
+        self.checkpoint = checkpoint
+
+    @checkpoint_wrapper
+    def forward(self, x):
+        """Return per-scale control maps (finest scale first).
+
+        The token sequence fed to attention is ordered
+        [global token | coarsest scale | ... | finest scale].
+        """
+        proj_in_iter = iter(self.proj_ins)
+        # proj_outs was built fine -> coarse; reverse to consume coarse -> fine.
+        proj_out_iter = iter(self.proj_outs[::-1])
+
+        x = self.model(x)
+        hints = [rearrange(next(proj_in_iter)(x), "b c h w -> b (h w) c")]
+        grid_sizes = [(x.shape[2], x.shape[3])]
+        token_lens = [(x.shape[2] * x.shape[3])]
+
+        for layer in self.downsample_layers:
+            x = layer(x)
+            h, w = x.shape[2], x.shape[3]
+            grid_sizes.append((h, w))
+            token_lens.append(h * w)
+            hints.append(rearrange(next(proj_in_iter)(x), "b c h w -> b (h w) c"))
+
+        # One global token from the spatial mean of the coarsest features.
+        hints.append(rearrange(
+            next(proj_in_iter)(x.mean(dim=[2, 3], keepdim=True)),
+            "b c h w -> b (h w) c"
+        ))
+
+        # Reverse so the sequence starts with the global token followed by
+        # coarse-to-fine scales.
+        hints = hints[::-1]
+        grid_sizes = grid_sizes[::-1]
+        token_lens = token_lens[::-1]
+        hints = torch.cat(hints, 1)
+        hints = self.attn_layer(hints, grid_size=grid_sizes, token_lens=token_lens) + hints
+        # for layer in self.transformer:
+        #     hints = layer(hints, grid_size=grid_sizes, token_lens=token_lens)
+
+        # Skip the global token (index 0) and split tokens back into
+        # spatial maps, one per scale.
+        prev_idx = 1
+        controls = []
+        for gs, token_len in zip(grid_sizes, token_lens):
+            control = hints[:, prev_idx: prev_idx + token_len]
+            control = rearrange(control, "b (h w) c -> b c h w", h=gs[0], w=gs[1])
+            controls.append(next(proj_out_iter)(control))
+            prev_idx = prev_idx + token_len
+        return controls[::-1]
+
+
+
+class Downsampler(nn.Module):
+ def __init__(self, scale_factor):
+ super().__init__()
+ self.scale_factor = scale_factor
+
+ def forward(self, x):
+ return F.interpolate(x, scale_factor=self.scale_factor, mode="bicubic")
+
+
+class SpatialConditionEncoder(nn.Module):
+    """Patchify a spatial condition, refine it with a small pre-norm
+    transformer (RoPE attention + feed-forward), and project out through a
+    zero-initialized conv."""
+
+    def __init__(
+        self,
+        in_dim,
+        dim,
+        out_dim,
+        patch_size,
+        n_layers = 4,
+    ):
+        super().__init__()
+        # Non-overlapping patch embedding (stride == kernel == patch_size).
+        self.patch_embed = nn.Conv2d(in_dim, dim, kernel_size=patch_size, stride=patch_size)
+        self.conv = nn.Sequential(nn.SiLU(), nn.Conv2d(dim, dim, kernel_size=3, padding=1))
+
+        # Each layer: [attn-norm, attention, ff-norm, feed-forward].
+        self.transformer = nn.ModuleList(
+            nn.ModuleList([
+                RMSNorm(dim),
+                MemoryEfficientAttention(dim, rope=True),
+                RMSNorm(dim),
+                FeedForward(dim, mult=2)
+            ]) for _ in range(n_layers)
+        )
+        # Zero-initialized projection: the module initially emits zeros.
+        self.out = nn.Sequential(
+            nn.SiLU(),
+            zero_module(nn.Conv2d(dim, out_dim, kernel_size=1, padding=0))
+        )
+
+    def forward(self, x):
+        x = self.patch_embed(x)
+        x = self.conv(x)
+
+        b, c, h, w = x.shape
+        # Flatten to tokens for attention; grid_size feeds the RoPE layer.
+        x = rearrange(x, "b c h w -> b (h w) c")
+        for norm, layer, norm2, ff in self.transformer:
+            x = layer(norm(x), grid_size=(h, w)) + x  # pre-norm attn + residual
+            x = ff(norm2(x)) + x  # pre-norm FFN + residual
+        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
+
+        return self.out(x)
diff --git a/refnet/modules/layers.py b/refnet/modules/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..68facc4a9e95d5f74b680dbf87a370f3b7eaae3a
--- /dev/null
+++ b/refnet/modules/layers.py
@@ -0,0 +1,99 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from refnet.util import default
+
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, dim: int, eps: float = 1e-6):
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim))
+
+ def _norm(self, x):
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+ def forward(self, x):
+ output = self._norm(x.float()).type_as(x)
+ return output * self.weight
+
+
+
+def init_(tensor):
+ dim = tensor.shape[-1]
+ std = 1 / math.sqrt(dim)
+ tensor.uniform_(-std, std)
+ return tensor
+
+
+# feedforward
+class GEGLU(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = default(dim_out, dim)
+ project_in = nn.Sequential(
+ nn.Linear(dim, inner_dim),
+ nn.GELU()
+ ) if not glu else GEGLU(dim, inner_dim)
+
+ self.net = nn.Sequential(
+ project_in,
+ nn.Dropout(dropout),
+ nn.Linear(inner_dim, dim_out)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def Normalize(in_channels):
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class Upsample(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ upsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, out_channels=None, padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ if use_conv:
+ self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=padding)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
+ if self.use_conv:
+ x = self.conv(x)
+ return x
\ No newline at end of file
diff --git a/refnet/modules/lora.py b/refnet/modules/lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..076e16f407bfdff559ee545073a9eaeb529ef76e
--- /dev/null
+++ b/refnet/modules/lora.py
@@ -0,0 +1,370 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from typing import Union, Dict, List
+from einops import rearrange
+from refnet.util import exists, default
+from refnet.modules.transformer import BasicTransformerBlock, SelfInjectedTransformerBlock
+
+
def get_module_safe(self, module_path: str):
    """Resolve ``module_path`` (dot-separated attribute names) starting from
    ``self``, raising a descriptive AttributeError if any step is missing."""
    node = self
    try:
        for name in module_path.split('.'):
            node = getattr(node, name)
    except AttributeError:
        raise AttributeError(f"Cannot find modules {module_path}")
    return node
+
+
def switch_lora(self, v, label=None):
    """Toggle the LoRA adapter ``label`` on the q/k/v projections of an
    attention module (bound onto hacked modules as a method)."""
    for proj in (self.to_q, self.to_k, self.to_v):
        proj.set_lora_active(v, label)
+
+
def lora_forward(self, x, context, mask, scale=1., scale_factor= None):
    """Masked cross-attention forward used when a single ("foreground")
    LoRA is injected; bound onto attention modules as ``masked_forward``.

    ``context`` holds foreground and background reference tokens
    concatenated along dim 1; the resized character ``mask`` blends the
    two attention results continuously.

    Args:
        x: Query input.
        context: Key/value input — [foreground, background] chunks on dim 1.
        mask: Character mask (b, c, h, w), resized to the current latent
            resolution via ``scale_factor``.
        scale: Attention scale forwarded to ``attn_forward``.
        scale_factor: Current latent size factor; must be set by the caller.
    """
    def qkv_forward(x, context):
        q = self.to_q(x)
        k = self.to_k(context)
        v = self.to_v(context)
        return q, k, v

    assert exists(scale_factor), "Scale factor must be assigned before masked attention"

    # Resize the spatial mask to the attention resolution and flatten it to
    # token order so it can weight the per-token attention outputs.
    # NOTE(review): bicubic resampling can overshoot outside [0, 1].
    mask = rearrange(
        F.interpolate(mask, scale_factor=scale_factor, mode="bicubic"),
        "b c h w -> b (h w) c"
    ).contiguous()

    # c1: foreground reference tokens, c2: background reference tokens.
    c1, c2 = context.chunk(2, dim=1)

    # Background region cross-attention (foreground LoRA switched off)
    if self.use_lora:
        self.switch_lora(False, "foreground")
    q2, k2, v2 = qkv_forward(x, c2)
    bg_out = self.attn_forward(q2, k2, v2, scale) * self.bg_scale

    # Character region cross-attention (foreground LoRA switched on)
    if self.use_lora:
        self.switch_lora(True, "foreground")
    q1, k1, v1 = qkv_forward(x, c1)
    fg_out = self.attn_forward(q1, k1, v1, scale) * self.fg_scale

    # Soften the foreground branch toward the background one, then blend
    # the two branches with the continuous mask.
    fg_out = fg_out * (1 - self.merge_scale) + bg_out * self.merge_scale
    return fg_out * mask + bg_out * (1 - mask)
    # return torch.where(mask > self.mask_threshold, fg_out, bg_out)
+
+
def dual_lora_forward(self, x, context, mask, scale=1., scale_factor=None):
    """
    This function hacks cross-attention layers.

    Dual-LoRA variant of ``lora_forward``: a "background" adapter is active
    for the background branch and a "foreground" adapter for the character
    branch; the two results are combined with a hard mask threshold.

    Args:
        x: Query input
        context: Key and value input ([foreground, background] chunks on dim 1)
        mask: Character mask
        scale: Attention scale
        scale_factor: Current latent size factor (must be set by the caller)
    """
    def qkv_forward(x, context):
        q = self.to_q(x)
        k = self.to_k(context)
        v = self.to_v(context)
        return q, k, v

    assert exists(scale_factor), "Scale factor must be assigned before masked attention"

    # Resize the mask to the attention resolution and flatten to token order.
    mask = rearrange(
        F.interpolate(mask, scale_factor=scale_factor, mode="bicubic"),
        "b c h w -> b (h w) c"
    ).contiguous()

    c1, c2 = context.chunk(2, dim=1)

    # Background region cross-attention (background LoRA on, foreground off)
    if self.use_lora:
        self.switch_lora(True, "background")
        self.switch_lora(False, "foreground")
    q2, k2, v2 = qkv_forward(x, c2)
    bg_out = self.attn_forward(q2, k2, v2, scale) * self.bg_scale

    # Foreground region cross-attention (foreground LoRA on, background off)
    if self.use_lora:
        self.switch_lora(False, "background")
        self.switch_lora(True, "foreground")
    q1, k1, v1 = qkv_forward(x, c1)
    fg_out = self.attn_forward(q1, k1, v1, scale) * self.fg_scale

    # Blend foreground toward background, then pick per token by threshold.
    fg_out = fg_out * (1 - self.merge_scale) + bg_out * self.merge_scale
    # return fg_out * mask + bg_out * (1 - mask)
    return torch.where(mask > self.mask_threshold, fg_out, bg_out)
+
+
+
class MultiLoraInjectedLinear(nn.Linear):
    """
    A linear layer that can hold multiple LoRA adapters and merge them.

    Each adapter, registered under a string label, is a low-rank
    (down, up) pair plus dropout. Adapters can be individually activated,
    rescaled, merged into the frozen base weights, and the pre-merge
    weights can be recovered afterwards.
    """
    def __init__(
        self,
        in_features,
        out_features,
        bias = False,
    ):
        super().__init__(in_features, out_features, bias)
        self.lora_adapters: Dict[str, Dict[str, nn.Module]] = {}  # {label: {up/down/dropout: layer}}
        self.lora_scales: Dict[str, float] = {}
        self.active_loras: Dict[str, bool] = {}
        # Lazily-filled snapshots of the pre-merge base weights.
        self.original_weight = None
        self.original_bias = None

        # Freeze original weights: only LoRA parameters should train.
        self.weight.requires_grad_(False)
        if exists(self.bias):
            self.bias.requires_grad_(False)

    def add_lora_adapter(self, label: str, r: int, scale: float = 1.0, dropout_p: float = 0.0):
        """Add a new LoRA adapter with the given label.

        ``r`` may also be a float, interpreted as a fraction of
        ``out_features``. The adapter is created as an exact no-op
        (zeroed up-projection weights and biases) and activated.
        """
        if isinstance(r, float):
            r = int(r * self.out_features)

        lora_down = nn.Linear(self.in_features, r, bias=self.bias is not None)
        lora_up = nn.Linear(r, self.out_features, bias=self.bias is not None)
        dropout = nn.Dropout(dropout_p)

        # Initialize so the adapter contributes nothing until trained.
        nn.init.normal_(lora_down.weight, std=1 / r)
        nn.init.zeros_(lora_up.weight)
        if exists(lora_up.bias):
            # Fix: nn.Linear's default bias init is non-zero, which made a
            # freshly added adapter perturb the output despite the zeroed
            # up-projection weights.
            nn.init.zeros_(lora_down.bias)
            nn.init.zeros_(lora_up.bias)

        self.lora_adapters[label] = {
            'down': lora_down,
            'up': lora_up,
            'dropout': dropout,
        }
        self.lora_scales[label] = scale
        self.active_loras[label] = True

        # Register as submodules so .to()/.parameters() track them.
        self.add_module(f'lora_down_{label}', lora_down)
        self.add_module(f'lora_up_{label}', lora_up)
        self.add_module(f'lora_dropout_{label}', dropout)

    def get_trainable_layers(self, label: str = None):
        """Get trainable (down, up) layers for one LoRA, or for all LoRAs."""
        layers = []
        if exists(label):
            if label in self.lora_adapters:
                adapter = self.lora_adapters[label]
                layers.extend([adapter['down'], adapter['up']])
        else:
            for adapter in self.lora_adapters.values():
                layers.extend([adapter['down'], adapter['up']])
        return layers

    def set_lora_active(self, active: bool, label: str):
        """Activate or deactivate a specific LoRA adapter (no-op if unknown)."""
        if label in self.active_loras:
            self.active_loras[label] = active

    def set_lora_scale(self, scale: float, label: str):
        """Set the scale for a specific LoRA adapter (no-op if unknown)."""
        if label in self.lora_scales:
            self.lora_scales[label] = scale

    def merge_lora_weights(self, labels: List[str] = None):
        """Merge the specified active LoRA adapters into the base weights.

        Snapshots the original weights on first use (so they can be restored
        via ``recover_original_weight``); merged adapters are deactivated.
        """
        if labels is None:
            labels = list(self.lora_adapters.keys())

        # Store original weights if not already stored
        if self.original_weight is None:
            self.original_weight = self.weight.clone()
            if exists(self.bias):
                self.original_bias = self.bias.clone()

        merged_weight = self.original_weight.clone()
        merged_bias = self.original_bias.clone() if exists(self.original_bias) else None

        for label in labels:
            if label in self.lora_adapters and self.active_loras.get(label, False):
                lora_up, lora_down = self.lora_adapters[label]['up'], self.lora_adapters[label]['down']
                scale = self.lora_scales[label]

                # W' = W + s * (Up @ Down); the bias folds the same way.
                lora_weight = lora_up.weight @ lora_down.weight
                merged_weight += scale * lora_weight

                if exists(merged_bias) and exists(lora_up.bias):
                    lora_bias = lora_up.bias + lora_up.weight @ lora_down.bias
                    merged_bias += scale * lora_bias

        # Update weights (kept frozen, consistent with __init__)
        self.weight = nn.Parameter(merged_weight, requires_grad=False)
        if exists(merged_bias):
            self.bias = nn.Parameter(merged_bias, requires_grad=False)

        # Deactivate all LoRAs after merging
        for label in labels:
            self.active_loras[label] = False

    def recover_original_weight(self):
        """Restore the pre-merge base weights and reactivate every adapter."""
        if self.original_weight is not None:
            # Fix: keep the restored base weights frozen; the previous code
            # recreated them with requires_grad=True, contradicting the
            # freezing done in __init__ and merge_lora_weights.
            self.weight = nn.Parameter(self.original_weight.clone(), requires_grad=False)
            if exists(self.original_bias):
                self.bias = nn.Parameter(self.original_bias.clone(), requires_grad=False)

        # Reactivate all LoRAs
        for label in self.active_loras:
            self.active_loras[label] = True

    def forward(self, input):
        """Base linear output plus the scaled contribution of each active LoRA."""
        output = super().forward(input)

        # Add contributions from active LoRAs
        for label, adapter in self.lora_adapters.items():
            if self.active_loras.get(label, False):
                lora_out = adapter['up'](adapter['dropout'](adapter['down'](input)))
                output = output + self.lora_scales[label] * lora_out

        return output
+
+
class LoraModules:
    """Injects and manages :class:`MultiLoraInjectedLinear` layers.

    For each config in ``lora_params`` the targeted linear layers under the
    given root module are replaced (or reused, when several configs hit the
    same layer), the parent attention modules get a ``masked_forward`` hack,
    and bulk activate/scale/merge/recover operations are exposed.
    """
    def __init__(self, sd, lora_params, *args, **kwargs):
        self.modules = {}  # label -> list of hacked parent modules
        self.multi_lora_layers: Dict[str, MultiLoraInjectedLinear] = {}  # path -> MultiLoraLayer

        for cfg in lora_params:
            root_module = get_module_safe(sd, cfg.pop("root_module"))
            label = cfg.pop("label", "lora")
            self.inject_lora(label, root_module, **cfg)

    def inject_lora(
        self,
        label,
        root_module,
        r,
        split_forward = False,
        target_keys = ("to_q", "to_k", "to_v"),
        filter_keys = None,
        target_class = None,
        scale = 1.0,
        dropout_p = 0.0,
    ):
        """Inject a LoRA adapter named ``label`` under ``root_module``.

        Layers are selected by name suffix (``target_keys``), excluded by
        substring (``filter_keys``), or matched by class (``target_class``,
        dotted import paths). ``split_forward`` installs the dual-LoRA
        masked forward on the parent instead of the single-LoRA one.
        """
        def check_condition(path, child, class_list):
            if exists(filter_keys) and any(path.find(key) > -1 for key in filter_keys):
                return False
            if exists(target_keys) and any(path.endswith(key) for key in target_keys):
                return True
            if exists(class_list) and any(
                isinstance(child, module_class) for module_class in class_list
            ):
                return True
            return False

        def retrieve_target_modules():
            from refnet.util import get_obj_from_str
            target_class_list = [get_obj_from_str(t) for t in target_class] if exists(target_class) else None

            modules = []
            for name, module in root_module.named_modules():
                for key, child in module._modules.items():
                    full_path = name + '.' + key if name else key
                    if check_condition(full_path, child, target_class_list):
                        modules.append((module, child, key, full_path))
            return modules

        modules: List[nn.Module] = []
        retrieved_modules = retrieve_target_modules()

        for parent, child, child_name, full_path in retrieved_modules:
            # Check if this layer already has a MultiLoraInjectedLinear
            if full_path in self.multi_lora_layers:
                # Add LoRA to existing MultiLoraInjectedLinear
                multi_lora_layer = self.multi_lora_layers[full_path]
                multi_lora_layer.add_lora_adapter(label, r, scale, dropout_p)
            else:
                # Check if the current layer is already a MultiLoraInjectedLinear
                if isinstance(child, MultiLoraInjectedLinear):
                    child.add_lora_adapter(label, r, scale, dropout_p)
                    self.multi_lora_layers[full_path] = child
                else:
                    # Replace with MultiLoraInjectedLinear and add first LoRA
                    multi_lora_layer = MultiLoraInjectedLinear(
                        in_features=child.weight.shape[1],
                        out_features=child.weight.shape[0],
                        bias=exists(child.bias),
                    )
                    multi_lora_layer.add_lora_adapter(label, r, scale, dropout_p)

                    # Fix: carry the pretrained weights over to the replacement
                    # layer -- previously it kept nn.Linear's random init and
                    # silently discarded the base model's weights. Also match
                    # the child's device/dtype so the swap is transparent.
                    with torch.no_grad():
                        multi_lora_layer.weight.copy_(child.weight)
                        if exists(child.bias):
                            multi_lora_layer.bias.copy_(child.bias)
                    multi_lora_layer.to(device=child.weight.device, dtype=child.weight.dtype)

                    parent._modules[child_name] = multi_lora_layer
                    self.multi_lora_layers[full_path] = multi_lora_layer

            # Hack the parent's masked attention forward and expose helpers.
            if split_forward:
                parent.masked_forward = dual_lora_forward.__get__(parent, parent.__class__)
            else:
                parent.masked_forward = lora_forward.__get__(parent, parent.__class__)

            parent.use_lora = True
            parent.switch_lora = switch_lora.__get__(parent, parent.__class__)
            modules.append(parent)

        self.modules[label] = modules
        print(f"Activated {label} lora with {len(self.multi_lora_layers)} layers")
        return self.multi_lora_layers, modules

    def get_trainable_layers(self, label = None):
        """Get all trainable layers, optionally filtered by label."""
        layers = []
        for lora_layer in self.multi_lora_layers.values():
            layers += lora_layer.get_trainable_layers(label)
        return layers

    def switch_lora(self, mode, label = None):
        """Activate/deactivate one labelled LoRA, or all of them."""
        if exists(label):
            for layer in self.multi_lora_layers.values():
                layer.set_lora_active(mode, label)
            for module in self.modules[label]:
                module.use_lora = mode
        else:
            for layer in self.multi_lora_layers.values():
                for lora_label in layer.lora_adapters.keys():
                    layer.set_lora_active(mode, lora_label)

            for modules in self.modules.values():
                for module in modules:
                    module.use_lora = mode

    def adjust_lora_scales(self, scale, label = None):
        """Set the scale of one labelled LoRA, or of all of them."""
        if exists(label):
            for layer in self.multi_lora_layers.values():
                layer.set_lora_scale(scale, label)
        else:
            for layer in self.multi_lora_layers.values():
                for lora_label in layer.lora_adapters.keys():
                    layer.set_lora_scale(scale, lora_label)

    def merge_lora(self, labels = None):
        """Merge the given LoRA labels (default: all) into the base weights."""
        if labels is None:
            labels = list(self.modules.keys())
        elif isinstance(labels, str):
            labels = [labels]

        for layer in self.multi_lora_layers.values():
            layer.merge_lora_weights(labels)

    def recover_lora(self):
        """Undo all merges, restoring the pre-merge base weights."""
        for layer in self.multi_lora_layers.values():
            layer.recover_original_weight()

    def get_lora_info(self):
        """Get information about all LoRA adapters."""
        info = {}
        for path, layer in self.multi_lora_layers.items():
            info[path] = {
                'labels': list(layer.lora_adapters.keys()),
                'active': {label: active for label, active in layer.active_loras.items()},
                'scales': layer.lora_scales.copy()
            }
        return info
\ No newline at end of file
diff --git a/refnet/modules/proj.py b/refnet/modules/proj.py
new file mode 100644
index 0000000000000000000000000000000000000000..b619a088f21ef1774a16ba33179fd968bd2e101c
--- /dev/null
+++ b/refnet/modules/proj.py
@@ -0,0 +1,142 @@
+import torch
+import torch.nn as nn
+
+from refnet.modules.layers import zero_module
+from refnet.modules.attention import MemoryEfficientAttention
+from refnet.modules.transformer import BasicTransformerBlock
+from refnet.util import checkpoint_wrapper, exists
+from refnet.util import load_weights
+
+
class NormalizedLinear(nn.Module):
    """Linear projection followed by LayerNorm over the output features."""

    def __init__(self, dim, output_dim, checkpoint=True):
        super().__init__()
        stages = [nn.Linear(dim, output_dim), nn.LayerNorm(output_dim)]
        self.layers = nn.Sequential(*stages)
        # Read by checkpoint_wrapper to toggle activation checkpointing.
        self.checkpoint = checkpoint

    @checkpoint_wrapper
    def forward(self, x):
        return self.layers(x)
+
+
class GlobalProjection(nn.Module):
    """Project a global embedding into a sequence of conditioning tokens.

    ``proj1`` expands the input into ``heads[0]`` sub-vectors of size
    ``dim_head``; the zero-initialized ``proj2`` maps each sub-vector to
    ``heads[1]`` tokens, yielding ``heads[0] * heads[1]`` tokens of size
    ``output_dim`` after LayerNorm.
    """

    def __init__(self, input_dim, output_dim, heads, dim_head=128, checkpoint=True):
        super().__init__()
        self.c_dim = output_dim
        self.dim_head = dim_head
        self.head = (heads[0], heads[0] * heads[1])

        self.proj1 = nn.Linear(input_dim, dim_head * heads[0])
        self.proj2 = nn.Sequential(
            nn.SiLU(),
            zero_module(nn.Linear(dim_head, output_dim * heads[1])),
        )
        self.norm = nn.LayerNorm(output_dim)
        self.checkpoint = checkpoint

    @checkpoint_wrapper
    def forward(self, x):
        n_sub, n_tokens = self.head
        sub = self.proj1(x).reshape(-1, n_sub, self.dim_head).contiguous()
        tokens = self.proj2(sub).reshape(-1, n_tokens, self.c_dim).contiguous()
        return self.norm(tokens)
+
+
class ClusterConcat(nn.Module):
    """Self-attend over reference tokens, keep the first ``token_length``
    of them, concatenate a conditioning embedding on the feature dim, and
    project the result to ``output_dim``."""

    def __init__(self, input_dim, c_dim, output_dim, dim_head=64, token_length=196, checkpoint=True):
        super().__init__()
        self.attn = MemoryEfficientAttention(input_dim, dim_head=dim_head)
        self.norm = nn.LayerNorm(input_dim)
        self.proj = nn.Sequential(
            nn.Linear(input_dim + c_dim, output_dim),
            nn.SiLU(),
            nn.Linear(output_dim, output_dim),
            nn.LayerNorm(output_dim),
        )
        self.token_length = token_length
        self.checkpoint = checkpoint

    @checkpoint_wrapper
    def forward(self, x, emb, fgbg=False, *args, **kwargs):
        tokens = self.norm(self.attn(x)[:, :self.token_length])
        fused = self.proj(torch.cat([tokens, emb], 2))

        # When fg/bg pairs are stacked along the batch, fold the two halves
        # into the token dimension instead.
        if fgbg:
            fused = torch.cat(torch.chunk(fused, 2), 1)
        return fused
+
+
class RecoveryClusterConcat(ClusterConcat):
    """ClusterConcat variant that refines background tokens with an extra
    self-attention transformer block when ``bg`` is set."""

    def __init__(self, input_dim, c_dim, output_dim, dim_head=64, *args, **kwargs):
        super().__init__(input_dim, c_dim, output_dim, dim_head=dim_head, *args, **kwargs)
        self.transformer = BasicTransformerBlock(
            output_dim, output_dim//dim_head, dim_head,
            disable_cross_attn=True, checkpoint=False
        )

    @checkpoint_wrapper
    def forward(self, x, emb, bg=False):
        tokens = self.norm(self.attn(x)[:, :self.token_length])
        out = self.proj(torch.cat([tokens, emb], 2))
        return self.transformer(out) if bg else out
+
+
class LogitClusterConcat(ClusterConcat):
    """ClusterConcat whose conditioning embedding is first mapped through a
    frozen AdaptiveMLP (optionally loaded from ``mlp_ckpt_path``)."""

    def __init__(self, c_dim, mlp_in_dim, mlp_ckpt_path=None, *args, **kwargs):
        super().__init__(c_dim=c_dim, *args, **kwargs)
        self.mlp = AdaptiveMLP(c_dim, mlp_in_dim)
        if exists(mlp_ckpt_path):
            self.mlp.load_state_dict(load_weights(mlp_ckpt_path), strict=True)

    @checkpoint_wrapper
    def forward(self, x, emb, bg=False):
        # The MLP acts as a fixed feature extractor: no gradients flow into it.
        with torch.no_grad():
            emb = self.mlp(emb).detach()
        return super().forward(x, emb, bg)
+
+
class AdaptiveMLP(nn.Module):
    """MLP that fuses the activations of all intermediate stages.

    The first stage maps ``in_dim`` to ``dim``; each of the remaining
    ``layers - 1`` stages is SiLU -> LayerNorm -> Linear.  All per-stage
    outputs are concatenated on the feature dim and fused back to ``dim``.
    """

    def __init__(self, dim, in_dim, layers=4, checkpoint=True):
        super().__init__()

        stages = [nn.Sequential(nn.Linear(in_dim, dim))]
        stages += [
            nn.Sequential(nn.SiLU(), nn.LayerNorm(dim), nn.Linear(dim, dim))
            for _ in range(layers - 1)
        ]
        self.mlp = nn.Sequential(*stages)
        self.fusion_layer = nn.Linear(dim * layers, dim, bias=False)
        self.norm = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    @checkpoint_wrapper
    def forward(self, x):
        features = []
        for stage in self.mlp:
            x = stage(x)
            features.append(x)

        fused = self.fusion_layer(torch.cat(features, dim=2))
        return self.norm(fused)
+
+
class Concat(nn.Module):
    """Stateless module concatenating two tensors along the last dimension."""

    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x, y, *args, **kwargs):
        return torch.cat((x, y), dim=-1)
\ No newline at end of file
diff --git a/refnet/modules/reference_net.py b/refnet/modules/reference_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..020fcd8337c990e347f0c7627404cb8db6ec7668
--- /dev/null
+++ b/refnet/modules/reference_net.py
@@ -0,0 +1,430 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from einops import rearrange
+from typing import Union
+from functools import partial
+
+from refnet.modules.unet_old import (
+ timestep_embedding,
+ conv_nd,
+ TimestepEmbedSequential,
+ exists,
+ ResBlock,
+ linear,
+ Downsample,
+ zero_module,
+ SelfTransformerBlock,
+ SpatialTransformer,
+)
+from refnet.modules.unet import DualCondUNetXL
+
+
def hack_inference_forward(model):
    """Rebind ``InferenceForward`` onto ``model`` as its ``forward`` method."""
    model.forward = InferenceForward.__get__(model, model.__class__)
+
+
def hack_unet_forward(unet):
    """Swap ``unet._forward`` for the enhanced variant, keeping the original
    reachable as ``unet.original_forward`` for ``restore_unet_forward``."""
    unet.original_forward = unet._forward
    forward_fn = enhanced_forward_xl if isinstance(unet, DualCondUNetXL) else enhanced_forward
    unet._forward = forward_fn.__get__(unet, unet.__class__)
+
+
def restore_unet_forward(unet):
    """Undo ``hack_unet_forward``; a no-op if the unet was never hacked."""
    if not hasattr(unet, "original_forward"):
        return
    unet._forward = unet.original_forward.__get__(unet, unet.__class__)
    del unet.original_forward
+
+
def modulation(x, scale, shift):
    """FiLM-style affine modulation: ``x * (1 + scale) + shift``."""
    gain = 1 + scale
    return x * gain + shift
+
+
def enhanced_forward(
    self,
    x: torch.Tensor,
    emb: torch.Tensor,
    hs_fg: torch.Tensor = None,
    hs_bg: torch.Tensor = None,
    mask: torch.Tensor = None,
    threshold: Union[float, torch.Tensor] = None,
    control: torch.Tensor = None,
    context: torch.Tensor = None,
    style_modulations: torch.Tensor = None,
    **additional_context
):
    """Replacement UNet ``_forward`` (installed by ``hack_unet_forward``)
    that injects reference features and optional style modulation.

    hs_fg / hs_bg: per-decoder-level reference features for the character
        (foreground) and background regions, injected through
        ``self.map_modules`` / ``self.warp_modules`` where the resized
        ``mask`` is above / below ``threshold``.
    control: iterable of hint residuals, consumed at the encoder indices
        listed in ``self.hint_encoder_index``.
    style_modulations: per-decoder-level style vectors applied as a
        channelwise scale/shift after each output block.
    """
    h = x.to(self.dtype)
    emb = emb.to(self.dtype)
    hs = []  # encoder skip features, popped in reverse by the decoder

    control_iter = iter(control)
    for idx, module in enumerate(self.input_blocks):
        h = module(h, emb, context, mask, **additional_context)

        # Add the control hint residual at the designated encoder blocks.
        if idx in self.hint_encoder_index:
            h += next(control_iter)

        hs.append(h)

    h = self.middle_block(h, emb, context, mask, **additional_context)

    for idx, module in enumerate(self.output_blocks):
        h_skip = hs.pop()

        if exists(mask) and exists(threshold):
            # inject foreground/background features: resize the mask to the
            # skip resolution and route each position to the fg/bg module.
            B, C, H, W = h_skip.shape
            cm = F.interpolate(mask, (H, W), mode="bicubic")
            h = torch.cat([h, torch.where(
                cm > threshold,
                self.map_modules[idx](h_skip, hs_fg[idx]) if exists(hs_fg) else h_skip,
                self.warp_modules[idx](h_skip, hs_bg[idx]) if exists(hs_bg) else h_skip
            )], 1)

        else:
            h = torch.cat([h, h_skip], 1)

        h = module(h, emb, context, mask, **additional_context)

        if exists(style_modulations):
            # Per-level style modulation: project (style + time embedding)
            # into a channelwise scale/shift pair.
            style_norm, emb_proj, style_proj = self.style_modules[idx]
            style_m = style_modulations[idx] + emb_proj(emb)
            style_m = style_proj(style_norm(style_m))[...,None,None]
            scale, shift = style_m.chunk(2, dim=1)

            h = modulation(h, scale, shift)

    return h
+
def enhanced_forward_xl(
    self,
    x: torch.Tensor,
    emb,
    z_fg: torch.Tensor = None,
    z_bg: torch.Tensor = None,
    hs_fg: torch.Tensor = None,
    hs_bg: torch.Tensor = None,
    mask: torch.Tensor = None,
    inject_mask: torch.Tensor = None,
    threshold: Union[float, torch.Tensor] = None,
    concat: torch.Tensor = None,
    control: torch.Tensor = None,
    context: torch.Tensor = None,
    style_modulations: torch.Tensor = None,
    **additional_context
):
    """Replacement ``_forward`` for DualCondUNetXL (installed by
    ``hack_unet_forward``): channel-concat conditioning, one-shot fg/bg
    latent injection in the encoder, per-level reference-feature injection
    in the decoder, optional style modulation, and hint residuals on both
    encoder and decoder indices.

    z_fg / z_bg: latents added via ``conv_fg`` / ``conv_bg`` once, after
        the first input block's output (each is cleared when consumed).
    hs_fg / hs_bg: per-level reference tokens attended into the skip
        features through ``map_modules`` / ``warp_modules`` where the
        resized ``inject_mask`` is above / below ``threshold``.
    """
    h = x.to(self.dtype)
    emb = emb.to(self.dtype)
    hs = []  # encoder skip features
    control_iter = iter(control)

    if exists(concat):
        # Residual conv over the channel-concatenated input.
        h = torch.cat([h, concat], 1)
        h = h + self.concat_conv(h)

    for idx, module in enumerate(self.input_blocks):
        h = module(h, emb, context, mask, **additional_context)

        # Control hint residual at the designated encoder blocks.
        if idx in self.hint_encoder_index:
            h += next(control_iter)

        # Inject fg/bg latents exactly once (cleared after first use).
        if exists(z_fg):
            h += self.conv_fg(z_fg)
            z_fg = None
        if exists(z_bg):
            h += self.conv_bg(z_bg)
            z_bg = None

        hs.append(h)

    h = self.middle_block(h, emb, context, mask, **additional_context)

    for idx, module in enumerate(self.output_blocks):
        h_skip = hs.pop()

        if exists(inject_mask) and exists(threshold):
            # inject foreground/background features: flatten the skip map to
            # tokens, attend against the reference tokens (offset by a
            # time-embedding projection), and add it back residually; the
            # resized inject_mask selects fg vs bg per spatial position.
            B, C, H, W = h_skip.shape
            cm = F.interpolate(inject_mask, (H, W), mode="bicubic")
            h = torch.cat([h, torch.where(
                cm > threshold,

                # foreground injection
                rearrange(
                    self.map_modules[idx][0](
                        rearrange(h_skip, "b c h w -> b (h w) c"),
                        hs_fg[idx] + self.map_modules[idx][1](emb).unsqueeze(1)
                    ), "b (h w) c -> b c h w", h=H, w=W
                ) + h_skip if exists(hs_fg) else h_skip,

                # background injection
                rearrange(
                    self.warp_modules[idx][0](
                        rearrange(h_skip, "b c h w -> b (h w) c"),
                        hs_bg[idx] + self.warp_modules[idx][1](emb).unsqueeze(1)
                    ), "b (h w) c -> b c h w", h=H, w=W
                ) + h_skip if exists(hs_bg) else h_skip
            )], 1)

        else:
            h = torch.cat([h, h_skip], 1)

        h = module(h, emb, context, mask, **additional_context)

        if exists(style_modulations):
            # Per-level style modulation: scale/shift from style + time emb.
            style_norm, emb_proj, style_proj = self.style_modules[idx]
            style_m = style_modulations[idx] + emb_proj(emb)
            style_m = style_proj(style_norm(style_m))[...,None,None]
            scale, shift = style_m.chunk(2, dim=1)

            h = modulation(h, scale, shift)

        # Decoder-side control hints (XL only).
        if idx in self.hint_decoder_index:
            h += next(control_iter)

    return h
+
def InferenceForward(self, x, timesteps=None, y=None, *args, **kwargs):
    """Inference-time ``forward`` (bound by ``hack_inference_forward``):
    builds the time/class embedding and delegates to the hacked ``_forward``."""
    t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
    emb = self.time_embed(t_emb).to(self.dtype)

    class_conditional = self.num_classes is not None
    assert (y is not None) == class_conditional, \
        "must specify y if and only if the model is class-conditional"

    if class_conditional:
        assert y.shape[0] == x.shape[0]
        emb = (emb + self.label_emb(y.to(self.dtype))).to(self.dtype)
    return self._forward(x, emb, *args, **kwargs)
+
+
class UNetEncoderXL(nn.Module):
    """UNet-style encoder mirroring an SDXL UNet's downsampling path that
    emits one zero-initialized "hint" per input block.

    With ``style_modulation`` each hint is a pooled vector projected to a
    scale/shift pair; otherwise it is a flattened spatial feature map.
    ``_forward`` returns the hints in reverse (decoder) order.
    """
    # Registry of supported transformer block types for attention levels.
    transformers = {
        "vanilla": SpatialTransformer,
    }

    def __init__(
        self,
        in_channels,
        model_channels,
        num_res_blocks,
        attention_resolutions,
        dropout = 0,
        channel_mult = (1, 2, 4, 8),
        conv_resample = True,
        dims = 2,
        num_classes = None,
        use_checkpoint = False,
        num_heads = -1,
        num_head_channels = -1,
        use_scale_shift_norm = False,
        resblock_updown = False,
        use_spatial_transformer = False, # custom transformer support
        transformer_depth = 1, # custom transformer support
        context_dim = None, # custom transformer support
        disable_self_attentions = None,
        disable_cross_attentions = None,
        num_attention_blocks = None,
        use_linear_in_transformer = False,
        adm_in_channels = None,
        transformer_type = "vanilla",
        style_modulation = False,
    ):
        super().__init__()
        if use_spatial_transformer:
            assert exists(
                context_dim) or disable_cross_attentions, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
            assert transformer_type in self.transformers.keys(), f'Assigned transformer is not implemented.. Choices: {self.transformers.keys()}'
            from omegaconf.listconfig import ListConfig
            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        time_embed_dim = model_channels * 4
        # NOTE(review): this partial is never used below -- the loops call
        # ResBlock directly. Dead code kept as-is for reference.
        resblock = partial(
            ResBlock,
            emb_channels=time_embed_dim,
            dropout=dropout,
            dims=dims,
            use_checkpoint=use_checkpoint,
            use_scale_shift_norm=use_scale_shift_norm,
        )

        # Factory for attention-level transformer blocks.
        transformer = partial(
            self.transformers[transformer_type],
            context_dim=context_dim,
            use_linear=use_linear_in_transformer,
            use_checkpoint=use_checkpoint,
            disable_self_attn=disable_self_attentions,
            disable_cross_attn=disable_cross_attentions,
        )

        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
        self.in_channels = in_channels
        self.model_channels = model_channels
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                 "as a list/tuple (per-level) with the same length as channel_mult")
            self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(
                map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                  f"attention will still not be set.")

        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = torch.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.style_modulation = style_modulation

        if isinstance(transformer_depth, int):
            transformer_depth = len(channel_mult) * [transformer_depth]

        # NOTE(review): recomputed; identical to the value assigned above.
        time_embed_dim = model_channels * 4
        zero_conv = partial(nn.Conv2d, kernel_size=1, stride=1, padding=0)

        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
            elif self.num_classes == "continuous":
                print("setting up linear c_adm embedding layer")
                self.label_emb = nn.Linear(1, time_embed_dim)
            elif self.num_classes == "sequential":
                assert adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
                        linear(adm_in_channels, time_embed_dim),
                        nn.SiLU(),
                        linear(time_embed_dim, time_embed_dim),
                    )
                )
            else:
                raise ValueError()

        # Stem: plain conv from in_channels to model_channels.
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        # One zero-initialized hint head per input block: a Linear producing
        # scale/shift pairs under style_modulation, else a 1x1 zero conv.
        self.zero_layers = nn.ModuleList([zero_module(
            nn.Linear(model_channels, model_channels * 2) if style_modulation else
            zero_conv(model_channels, model_channels)
        )])

        ch = model_channels
        ds = 1  # current downsampling factor
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    num_heads = ch // num_head_channels

                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                        layers.append(
                            SelfTransformerBlock(ch, num_head_channels)
                            if not use_spatial_transformer
                            else transformer(
                                ch, num_heads, num_head_channels, depth=transformer_depth[level]
                            )
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self.zero_layers.append(zero_module(
                    nn.Linear(ch, ch * 2) if style_modulation else zero_conv(ch, ch)
                ))

            # Downsample between levels (except after the last one).
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(TimestepEmbedSequential(
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=out_ch,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                        down=True,
                    ) if resblock_updown else Downsample(
                        ch, conv_resample, dims=dims, out_channels=out_ch
                    )
                ))
                self.zero_layers.append(zero_module(
                    nn.Linear(out_ch, min(model_channels * 8, out_ch * 4)) if style_modulation else
                    zero_conv(out_ch, out_ch)
                ))
                ch = out_ch
                ds *= 2


    def forward(self, x, timesteps = None, y = None, *args, **kwargs):
        """Build the time (+ optional class) embedding, then run the encoder."""
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
        emb = self.time_embed(t_emb)

        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y.to(self.dtype))

        hs = self._forward(x, emb, *args, **kwargs)
        return hs

    def _forward(self, x, emb, context = None, **additional_context):
        """Run all input blocks and collect one hint per block.

        Returns hints in reverse (decoder) order: pooled vectors when
        ``style_modulation`` is on, flattened "b (h w) c" maps otherwise.
        """
        hints = []
        h = x.to(self.dtype)

        for idx, module in enumerate(self.input_blocks):
            h = module(h, emb, context, **additional_context)

            if self.style_modulation:
                # Global-average-pool, then project to a modulation vector.
                hint = self.zero_layers[idx](h.mean(dim=[2, 3]))
                hints.append(hint)

            else:
                hint = self.zero_layers[idx](h)
                hint = rearrange(hint, "b c h w -> b (h w) c").contiguous()
                hints.append(hint)

        hints.reverse()
        return hints
\ No newline at end of file
diff --git a/refnet/modules/transformer.py b/refnet/modules/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6406307020e811617b812fd45bafbc4bcf0d079
--- /dev/null
+++ b/refnet/modules/transformer.py
@@ -0,0 +1,232 @@
+import torch
+import torch.nn as nn
+
+from functools import partial
+from einops import rearrange
+
+from refnet.util import checkpoint_wrapper, exists
+from refnet.modules.layers import FeedForward, Normalize, zero_module, RMSNorm
+from refnet.modules.attention import MemoryEfficientAttention, MultiModalAttention, MultiScaleCausalAttention
+
+
class BasicTransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention -> optional cross-attention -> feed-forward.

    Attention implementations for each stage are chosen by name from
    ATTENTION_MODES.
    NOTE(review): the `casual_*` keyword names look like misspellings of
    "causal"; kept as-is because callers/configs use these exact keywords.
    """
    # Maps config strings to the attention classes used for attn1/attn2.
    ATTENTION_MODES = {
        "vanilla": MemoryEfficientAttention,
        "multi-scale": MultiScaleCausalAttention,
        "multi-modal": MultiModalAttention,
    }
    def __init__(
        self,
        dim,
        n_heads = None,
        d_head = 64,
        dropout = 0.,
        context_dim = None,
        gated_ff = True,
        ff_mult = 4,
        checkpoint = True,
        disable_self_attn = False,
        disable_cross_attn = False,
        self_attn_type = "vanilla",
        cross_attn_type = "vanilla",
        rotary_positional_embedding = False,
        context_dim_2 = None,
        casual_self_attn = False,
        casual_cross_attn = False,
        qk_norm = False,
        norm_type = "layer",
    ):
        super().__init__()
        assert self_attn_type in self.ATTENTION_MODES
        assert cross_attn_type in self.ATTENTION_MODES
        self_attn_cls = self.ATTENTION_MODES[self_attn_type]
        crossattn_cls = self.ATTENTION_MODES[cross_attn_type]

        if norm_type == "layer":
            norm_cls = nn.LayerNorm
        elif norm_type == "rms":
            norm_cls = RMSNorm
        else:
            raise NotImplementedError(f"Normalization {norm_type} is not implemented.")

        self.dim = dim
        self.disable_self_attn = disable_self_attn
        self.disable_cross_attn = disable_cross_attn

        # When self-attention is disabled, attn1 instead cross-attends over
        # `context` (context_dim is forwarded to it).
        self.attn1 = self_attn_cls(
            query_dim = dim,
            heads = n_heads,
            dim_head = d_head,
            dropout = dropout,
            context_dim = context_dim if self.disable_self_attn else None,
            casual = casual_self_attn,
            rope = rotary_positional_embedding,
            qk_norm = qk_norm
        )
        # Cross-attention (and its norm, below) are omitted entirely when disabled.
        self.attn2 = crossattn_cls(
            query_dim = dim,
            context_dim = context_dim,
            context_dim_2 = context_dim_2,
            heads = n_heads,
            dim_head = d_head,
            dropout = dropout,
            casual = casual_cross_attn
        ) if not disable_cross_attn else None

        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, mult=ff_mult)
        self.norm1 = norm_cls(dim)
        self.norm2 = norm_cls(dim) if not disable_cross_attn else None
        self.norm3 = norm_cls(dim)
        # Runtime knobs; adjusted externally (e.g. by sampling code).
        self.reference_scale = 1
        self.scale_factor = None
        self.checkpoint = checkpoint

    @checkpoint_wrapper
    def forward(self, x, context=None, mask=None, emb=None, **kwargs):
        """Residual self-attn, optional residual cross-attn over `context`, residual FF.

        `emb` is accepted for interface parity with subclasses but unused here.
        """
        x = self.attn1(self.norm1(x), **kwargs) + x
        if not self.disable_cross_attn:
            x = self.attn2(self.norm2(x), context, mask, self.reference_scale, self.scale_factor) + x
        x = self.ff(self.norm3(x)) + x
        return x
+
+
class SelfInjectedTransformerBlock(BasicTransformerBlock):
    """Transformer block whose self-attention can be injected with a cached
    reference feature ("bank"), reference-net style.

    When `self.bank` is set, attn1 attends over the current tokens combined
    with the bank (concatenated or added, per `injection_type`); otherwise it
    falls back to the plain BasicTransformerBlock path.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.bank = None                    # reference features, set externally
        self.time_proj = None               # optional timestep projection for the bank
        self.injection_type = "concat"      # "concat" or additive injection
        # Bound parent forward, used when no reference bank is injected.
        self.forward_without_bank = super().forward

    @checkpoint_wrapper
    def forward(self, x, context=None, mask=None, emb=None, **kwargs):
        if exists(self.bank):
            bank = self.bank
            # Broadcast a single cached bank across the batch (e.g. CFG doubling).
            if bank.shape[0] != x.shape[0]:
                bank = bank.repeat(x.shape[0], 1, 1)
            # Optionally condition the bank on the current timestep embedding.
            if exists(self.time_proj) and exists(emb):
                bank = bank + self.time_proj(emb).unsqueeze(1)
            x_in = self.norm1(x)

            # Keep both attentions using the same masking threshold.
            self.attn1.mask_threshold = self.attn2.mask_threshold
            x = self.attn1(
                x_in,
                torch.cat([x_in, bank], 1) if self.injection_type == "concat" else x_in + bank,
                mask = mask,
                scale_factor = self.scale_factor,
                **kwargs
            ) + x

            x = self.attn2(
                self.norm2(x),
                context,
                mask = mask,
                scale = self.reference_scale,
                scale_factor = self.scale_factor
            ) + x

            x = self.ff(self.norm3(x)) + x
        else:
            # BUGFIX: forward the extra kwargs (e.g. `grid_size` supplied by
            # Transformer.forward) to the parent path; they were silently
            # dropped before, making the two paths behave inconsistently.
            x = self.forward_without_bank(x, context, mask, emb, **kwargs)
        return x
+
+
class SelfTransformerBlock(nn.Module):
    """Single attention + zero-initialized MLP block.

    With `reshape=True` (default) the input is a (b, c, h, w) feature map that
    is flattened to tokens and restored afterwards; with `reshape=False` the
    input is assumed to already be token-shaped (b, n, c).
    """

    def __init__(
        self,
        dim,
        dim_head = 64,
        dropout = 0.,
        mlp_ratio = 4,
        checkpoint = True,
        casual_attn = False,
        reshape = True
    ):
        super().__init__()
        self.attn = MemoryEfficientAttention(query_dim=dim, heads=dim//dim_head, dropout=dropout, casual=casual_attn)
        # Last linear is zero-initialized so the block starts as identity.
        self.ff = nn.Sequential(
            nn.Linear(dim, dim * mlp_ratio),
            nn.SiLU(),
            zero_module(nn.Linear(dim * mlp_ratio, dim))
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.reshape = reshape
        self.checkpoint = checkpoint

    @checkpoint_wrapper
    def forward(self, x, context=None):
        if self.reshape:
            # BUGFIX: only unpack a 4D shape when reshaping. The previous
            # unconditional `b, c, h, w = x.shape` raised a ValueError on 3D
            # token input whenever reshape=False.
            b, c, h, w = x.shape
            x = rearrange(x, 'b c h w -> b (h w) c').contiguous()

        x = self.attn(self.norm1(x), context) + x
        x = self.ff(self.norm2(x)) + x

        if self.reshape:
            x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        return x
+
+
class Transformer(nn.Module):
    """Spatial transformer wrapper: project a (b, c, h, w) map to tokens, run
    `depth` transformer blocks, project back, and add the input residual.

    `use_linear` picks linear projections applied after flattening; otherwise
    1x1 convolutions are applied before flattening.
    """
    transformer_type = {
        "vanilla": BasicTransformerBlock,
        "self-injection": SelfInjectedTransformerBlock,
    }
    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None, use_linear=False,
                 use_checkpoint=True, type="vanilla", transformer_config=None, **kwargs):
        super().__init__()
        transformer_block = self.transformer_type[type]
        # Normalize context_dim to exactly one entry per block. (The old code
        # re-tested `isinstance(context_dim, list)` right after forcing it to
        # a list; a single pad-to-depth step is equivalent.)
        if not isinstance(context_dim, list):
            context_dim = [context_dim]
        if len(context_dim) != depth:
            context_dim = depth * [context_dim[0]]

        proj_layer = nn.Linear if use_linear else partial(nn.Conv2d, kernel_size=1, stride=1, padding=0)
        inner_dim = n_heads * d_head

        self.in_channels = in_channels
        self.proj_in = proj_layer(in_channels, inner_dim)
        self.transformer_blocks = nn.ModuleList([
            transformer_block(
                inner_dim,
                n_heads,
                d_head,
                dropout = dropout,
                context_dim = context_dim[d],
                checkpoint = use_checkpoint,
                **(transformer_config or {}),
                **kwargs
            ) for d in range(depth)
        ])
        # Zero-initialized output projection: the whole module starts as identity.
        self.proj_out = zero_module(proj_layer(inner_dim, in_channels))
        self.norm = Normalize(in_channels)
        self.use_linear = use_linear

    def forward(self, x, context=None, mask=None, emb=None, *args, **additional_context):
        # note: if no context is given, cross-attention defaults to self-attention
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        # Conv projection happens on the 4D map; linear projection on tokens.
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            x = block(x, context=context, mask=mask, emb=emb, grid_size=(h, w), *args, **additional_context)
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in
+
+
def SpatialTransformer(*args, **kwargs):
    """Factory: a Transformer built from vanilla attention blocks."""
    return Transformer(*args, type="vanilla", **kwargs)
+
def SelfInjectTransformer(*args, **kwargs):
    """Factory: a Transformer built from reference-injectable blocks."""
    return Transformer(*args, type="self-injection", **kwargs)
diff --git a/refnet/modules/unet.py b/refnet/modules/unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6ba95f7c54dd33f7752ccee856a5e754dabc54c
--- /dev/null
+++ b/refnet/modules/unet.py
@@ -0,0 +1,421 @@
+import torch
+import torch.nn as nn
+
+from functools import partial
+from refnet.modules.attention import MemoryEfficientAttention
+from refnet.util import exists
+from refnet.modules.transformer import (
+ SelfTransformerBlock,
+ Transformer,
+ SpatialTransformer,
+ SelfInjectTransformer,
+)
+from refnet.ldm.openaimodel import (
+ timestep_embedding,
+ conv_nd,
+ TimestepBlock,
+ zero_module,
+ ResBlock,
+ linear,
+ Downsample,
+ Upsample,
+ normalization,
+)
+
+
def hack_inference_forward(model):
    """Monkey-patch `model.forward` with the inference-only variant below."""
    model.forward = InferenceForward.__get__(model, model.__class__)


def InferenceForward(self, x, timesteps=None, y=None, *args, **kwargs):
    """Inference forward pass: timestep (+ optional label) embedding, UNet body,
    then the output head applied in the input's dtype.
    """
    t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
    emb = self.time_embed(t_emb).to(self.dtype)
    assert (y is not None) == (
        self.num_classes is not None
    ), "must specify y if and only if the model is class-conditional"
    if self.num_classes is not None:
        assert y.shape[0] == x.shape[0]
        emb = emb + self.label_emb(y.to(emb.device))
    emb = emb.to(self.dtype)
    h = self._forward(x, emb, *args, **kwargs)
    # Cast back to the caller's dtype before the final conv head.
    return self.out(h.to(x.dtype))
+
+
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """
    # Dispatch constants
    _D_TIMESTEP = 0
    _D_TRANSFORMER = 1
    _D_OTHER = 2

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Cache dispatch types at init (before FSDP wrapping), so forward()
        # needs no isinstance checks and is immune to FSDP wrapper breakage.
        self._dispatch = tuple(
            self._D_TIMESTEP if isinstance(layer, TimestepBlock) else
            self._D_TRANSFORMER if isinstance(layer, Transformer) else
            self._D_OTHER
            for layer in self
        )

    def forward(self, x, emb=None, context=None, mask=None, **additional_context):
        """Run children in order, routing `emb` to TimestepBlocks and
        conditioning (`context`, `mask`, extras, plus `emb`) to Transformers;
        all other layers receive only the activation.
        """
        for layer, d in zip(self, self._dispatch):
            if d == self._D_TIMESTEP:
                x = layer(x, emb)
            elif d == self._D_TRANSFORMER:
                x = layer(x, context, mask, emb, **additional_context)
            else:
                x = layer(x)
        return x
+
+
+
+class UNetModel(nn.Module):
+ transformers = {
+ "vanilla": SpatialTransformer,
+ "selfinj": SelfInjectTransformer,
+ }
    def __init__(
        self,
        in_channels,
        model_channels,
        num_res_blocks,
        attention_resolutions,
        out_channels = 4,
        dropout = 0,
        channel_mult = (1, 2, 4, 8),
        conv_resample = True,
        dims = 2,
        num_classes = None,
        use_checkpoint = False,
        num_heads = -1,
        num_head_channels = -1,
        use_scale_shift_norm = False,
        resblock_updown = False,
        use_spatial_transformer = False,  # custom transformer support
        transformer_depth = 1,  # custom transformer support
        context_dim = None,  # custom transformer support
        disable_self_attentions = None,
        disable_cross_attentions = False,
        num_attention_blocks = None,
        use_linear_in_transformer = False,
        adm_in_channels = None,
        transformer_type = "vanilla",
        map_module = False,
        warp_module = False,
        style_modulation = False,
        discard_final_layers = False,  # for reference net
        additional_transformer_config = None,
        in_channels_fg = None,
        in_channels_bg = None,
    ):
        """Build the diffusion UNet: stem conv, encoder (`input_blocks`),
        bottleneck (`middle_block`), decoder (`output_blocks`), optional
        auxiliary map/warp/style modules, and the output head.

        `attention_resolutions` lists downsampling factors (ds) at which
        attention/transformer layers are inserted.
        """
        super().__init__()
        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            from omegaconf.listconfig import ListConfig
            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        assert num_heads > -1 or num_head_channels > -1, 'Either num_heads or num_head_channels has to be set'
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                 "as a list/tuple (per-level) with the same length as channel_mult")
            self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                  f"attention will still not be set.")

        self.num_classes = num_classes
        self.model_channels = model_channels
        self.dtype = torch.float32

        if isinstance(transformer_depth, int):
            transformer_depth = len(channel_mult) * [transformer_depth]
        transformer_depth_middle = transformer_depth[-1]
        time_embed_dim = model_channels * 4
        # Pre-bind the arguments shared by every ResBlock / Transformer below.
        resblock = partial(
            ResBlock,
            emb_channels = time_embed_dim,
            dropout = dropout,
            dims = dims,
            use_checkpoint = use_checkpoint,
            use_scale_shift_norm = use_scale_shift_norm,
        )
        transformer = partial(
            self.transformers[transformer_type],
            context_dim = context_dim,
            use_linear = use_linear_in_transformer,
            use_checkpoint = use_checkpoint,
            disable_self_attn = disable_self_attentions,
            disable_cross_attn = disable_cross_attentions,
            transformer_config = additional_transformer_config
        )

        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
            elif self.num_classes == "continuous":
                print("setting up linear c_adm embedding layer")
                self.label_emb = nn.Linear(1, time_embed_dim)
            elif self.num_classes == "sequential":
                assert adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
                        linear(adm_in_channels, time_embed_dim),
                        nn.SiLU(),
                        linear(time_embed_dim, time_embed_dim),
                    )
                )
            else:
                raise ValueError()

        # ---- Encoder ----
        self.input_blocks = nn.ModuleList([
            TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))
        ])
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1  # current downsampling factor
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [resblock(ch, out_channels=mult * model_channels)]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels > -1:
                        current_num_heads = ch // num_head_channels
                        current_head_dim = num_head_channels
                    else:
                        current_num_heads = num_heads
                        current_head_dim = ch // num_heads

                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                        layers.append(
                            SelfTransformerBlock(ch, current_head_dim)
                            if not use_spatial_transformer
                            else transformer(
                                ch, current_num_heads, current_head_dim,
                                depth=transformer_depth[level],
                            )
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(TimestepEmbedSequential(
                    resblock(ch, out_channels=out_ch, down=True) if resblock_updown
                    else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                ))
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2

        # ---- Bottleneck ----
        if num_head_channels > -1:
            current_num_heads = ch // num_head_channels
            current_head_dim = num_head_channels
        else:
            current_num_heads = num_heads
            current_head_dim = ch // num_heads
        self.middle_block = TimestepEmbedSequential(
            resblock(ch),
            SelfTransformerBlock(ch, current_head_dim) if not use_spatial_transformer
            else transformer(ch, current_num_heads, current_head_dim, depth=transformer_depth_middle),
            resblock(ch),
        )

        # ---- Decoder (+ optional per-level auxiliary modules) ----
        self.output_blocks = nn.ModuleList([])
        self.map_modules = nn.ModuleList([])
        self.warp_modules = nn.ModuleList([])
        self.style_modules = nn.ModuleList([])

        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(self.num_res_blocks[level] + 1):
                ich = input_block_chans.pop()
                layers = [resblock(ch + ich, out_channels=model_channels * mult)]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    if num_head_channels > -1:
                        current_num_heads = ch // num_head_channels
                        current_head_dim = num_head_channels
                    else:
                        current_num_heads = num_heads
                        current_head_dim = ch // num_heads

                    if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
                        layers.append(
                            SelfTransformerBlock(ch, current_head_dim) if not use_spatial_transformer
                            else transformer(
                                ch, current_num_heads, current_head_dim, depth=transformer_depth[level]
                            )
                        )
                if level and i == self.num_res_blocks[level]:
                    out_ch = ch
                    layers.append(
                        resblock(ch, up=True) if resblock_updown else Upsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
            if level == 0 and discard_final_layers:
                # Reference nets stop here: the remaining aux/out layers are unused.
                break

            if map_module:
                # NOTE(review): one module per resolution level, sized by that
                # level's last skip width `ich`; this assumes num_head_channels
                # is set (> -1) whenever map/warp modules are enabled — TODO confirm.
                self.map_modules.append(nn.ModuleList([
                    MemoryEfficientAttention(
                        ich,
                        heads = ich // num_head_channels,
                        dim_head = num_head_channels
                    ),
                    nn.Linear(time_embed_dim, ich)
                ]))

            if warp_module:
                self.warp_modules.append(nn.ModuleList([
                    MemoryEfficientAttention(
                        ich,
                        heads = ich // num_head_channels,
                        dim_head = num_head_channels
                    ),
                    nn.Linear(time_embed_dim, ich)
                ]))

                # self.warp_modules.append(nn.ModuleList([
                #     SpatialTransformer(ich, ich//num_head_channels, num_head_channels),
                #     nn.Linear(time_embed_dim, ich)
                # ]))

            if style_modulation:
                self.style_modules.append(nn.ModuleList([
                    nn.LayerNorm(ch*2),
                    nn.Linear(time_embed_dim, ch*2),
                    zero_module(nn.Linear(ch*2, ch*2))
                ]))

        if not discard_final_layers:
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
            )

        # Optional zero-initialized stems for extra foreground/background inputs.
        self.conv_fg = zero_module(
            conv_nd(dims, in_channels_fg, model_channels, 3, padding=1)
        ) if exists(in_channels_fg) else None
        self.conv_bg = zero_module(
            conv_nd(dims, in_channels_bg, model_channels, 3, padding=1)
        ) if exists(in_channels_bg) else None
+
    def forward(self, x, timesteps=None, y=None, *args, **kwargs):
        """Training forward: timestep (+ optional label) embedding, UNet body, output head.

        Note: the head runs in model dtype and the result is cast to `x.dtype`
        afterwards (InferenceForward casts before the head instead).
        """
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
        emb = self.time_embed(t_emb)
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y.to(self.dtype))

        h = self._forward(x, emb, *args, **kwargs)
        return self.out(h).to(x.dtype)
+
+ def _forward(
+ self,
+ x,
+ emb,
+ control = None,
+ context = None,
+ mask = None,
+ **additional_context
+ ):
+ hs = []
+ h = x.to(self.dtype)
+
+ for module in self.input_blocks:
+ h = module(h, emb, context, mask, **additional_context)
+ hs.append(h)
+
+ h = self.middle_block(h, emb, context, mask, **additional_context)
+
+ for module in self.output_blocks:
+ h = torch.cat([h, hs.pop()], dim=1)
+ h = module(h, emb, context, mask, **additional_context)
+ return h
+
+
class DualCondUNetXL(UNetModel):
    """UNet that adds externally computed hint features to fixed encoder
    (and optionally decoder) block outputs, and supports channel-concat
    conditioning on the input.
    """

    def __init__(
        self,
        hint_encoder_index = (0, 3, 6, 8),
        hint_decoder_index = (),
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        # Block indices after which one hint tensor from `control` is consumed.
        self.hint_encoder_index = hint_encoder_index
        self.hint_decoder_index = hint_decoder_index

    def _forward(self, x, emb, concat=None, control=None, context=None, mask=None, **additional_context):
        """UNet pass; `control` must yield one tensor per hint index.

        Raises:
            ValueError: if hint indices are configured but `control` is None
                (previously this surfaced as an opaque
                ``TypeError: 'NoneType' object is not iterable``).
        """
        h = x.to(self.dtype)
        hs = []

        # Optional channel-concat conditioning (e.g. sketch latents).
        if exists(concat):
            h = torch.cat([h, concat], 1)

        if not exists(control):
            if self.hint_encoder_index or self.hint_decoder_index:
                raise ValueError("DualCondUNetXL requires `control` hint features")
            control = ()
        control_iter = iter(control)
        for idx, module in enumerate(self.input_blocks):
            h = module(h, emb, context, mask, **additional_context)

            if idx in self.hint_encoder_index:
                h += next(control_iter)
            hs.append(h)

        h = self.middle_block(h, emb, context, mask, **additional_context)

        for idx, module in enumerate(self.output_blocks):
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context, mask, **additional_context)

            if idx in self.hint_decoder_index:
                h += next(control_iter)

        return h
+
+
class ReferenceNet(UNetModel):
    """Encoder-style UNet run only for its side effects (e.g. populating
    attention banks); it produces no output tensor.
    """

    def __init__(self, *args, **kwargs):
        # BUGFIX: force the flag through kwargs. The previous
        # `super().__init__(discard_final_layers=True, *args, **kwargs)` raised
        # "got multiple values for argument" whenever a config also supplied
        # `discard_final_layers`.
        kwargs["discard_final_layers"] = True
        super().__init__(*args, **kwargs)

    def forward(self, x, timesteps=None, y=None, *args, **kwargs):
        """Run the network for side effects; always returns None."""
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
        emb = self.time_embed(t_emb)

        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y.to(self.dtype))
        # Return explicitly (the result is None) instead of silently dropping it.
        return self._forward(x, emb, *args, **kwargs)

    def _forward(self, *args, **kwargs):
        # Run the parent pass for its side effects only; discard the output.
        super()._forward(*args, **kwargs)
        return None
\ No newline at end of file
diff --git a/refnet/modules/unet_old.py b/refnet/modules/unet_old.py
new file mode 100644
index 0000000000000000000000000000000000000000..1474fb552a9208918a2ad47e6453a620c90aea59
--- /dev/null
+++ b/refnet/modules/unet_old.py
@@ -0,0 +1,596 @@
+import torch
+import torch.nn as nn
+
+from functools import partial
+from refnet.util import exists
+from refnet.modules.transformer import (
+ SelfTransformerBlock,
+ Transformer,
+ SpatialTransformer,
+ rearrange
+)
+from refnet.ldm.openaimodel import (
+ timestep_embedding,
+ conv_nd,
+ TimestepBlock,
+ zero_module,
+ ResBlock,
+ linear,
+ Downsample,
+ Upsample,
+ normalization,
+)
+
# Probe for optional xformers support. Catch Exception (not a bare `except:`)
# so Ctrl-C / SystemExit during import are not swallowed; broken installs that
# raise non-ImportError errors at import time are still treated as unavailable.
try:
    import xformers
    import xformers.ops
    XFORMERS_IS_AVAILBLE = True
except Exception:
    XFORMERS_IS_AVAILBLE = False
+
+
def hack_inference_forward(model):
    """Monkey-patch `model.forward` with the inference-only variant below."""
    model.forward = InferenceForward.__get__(model, model.__class__)

def InferenceForward(self, x, timesteps=None, y=None, *args, **kwargs):
    """Inference forward pass: timestep (+ optional label) embedding, UNet body,
    then the output head applied in the input's dtype.
    """
    t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
    emb = self.time_embed(t_emb).to(self.dtype)
    assert (y is not None) == (
        self.num_classes is not None
    ), "must specify y if and only if the model is class-conditional"

    if self.num_classes is not None:
        assert y.shape[0] == x.shape[0]
        emb = emb + self.label_emb(y.to(emb.device))
    emb = emb.to(self.dtype)
    h = self._forward(x, emb, *args, **kwargs)
    # Cast back to the caller's dtype before the final conv head.
    return self.out(h.to(x.dtype))
+
+
+
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb, context=None, mask=None, **additional_context):
        # Per-child dispatch: TimestepBlocks get the embedding, Transformers
        # get conditioning (note: `emb` itself is NOT forwarded to
        # Transformers here), and everything else just gets the activation.
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            elif isinstance(layer, Transformer):
                x = layer(x, context, mask, **additional_context)
            else:
                x = layer(x)
        return x
+
+
+
+class UNetModel(nn.Module):
+ transformers = {
+ "vanilla": SpatialTransformer,
+ }
    def __init__(
        self,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout = 0,
        channel_mult = (1, 2, 4, 8),
        conv_resample = True,
        dims = 2,
        num_classes = None,
        use_checkpoint = False,
        num_heads = -1,
        num_head_channels = -1,
        use_scale_shift_norm = False,
        resblock_updown = False,
        use_spatial_transformer = False,  # custom transformer support
        transformer_depth = 1,  # custom transformer support
        context_dim = None,  # custom transformer support
        disable_self_attentions = None,
        disable_cross_attentions = None,
        num_attention_blocks = None,
        use_linear_in_transformer = False,
        adm_in_channels = None,
        transformer_type = "vanilla",
        map_module = False,
        warp_module = False,
        style_modulation = False,
    ):
        """Legacy UNet constructor: stem conv, encoder, bottleneck, decoder,
        optional per-level map/warp/style modules, and the output head.

        `attention_resolutions` lists downsampling factors (ds) at which
        attention/transformer layers are inserted.
        """
        super().__init__()
        if use_spatial_transformer:
            assert exists(context_dim) or disable_cross_attentions, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
            assert transformer_type in self.transformers.keys(), f'Assigned transformer is not implemented.. Choices: {self.transformers.keys()}'
        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            from omegaconf.listconfig import ListConfig
            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        assert num_heads > -1 or num_head_channels > -1, 'Either num_heads or num_head_channels has to be set'
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                 "as a list/tuple (per-level) with the same length as channel_mult")
            self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                  f"attention will still not be set.")

        self.num_classes = num_classes
        self.model_channels = model_channels
        self.dtype = torch.float32

        if isinstance(transformer_depth, int):
            transformer_depth = len(channel_mult) * [transformer_depth]
        transformer_depth_middle = transformer_depth[-1]
        time_embed_dim = model_channels * 4
        # Pre-bind the arguments shared by every ResBlock / Transformer below.
        resblock = partial(
            ResBlock,
            emb_channels=time_embed_dim,
            dropout=dropout,
            dims=dims,
            use_checkpoint=use_checkpoint,
            use_scale_shift_norm=use_scale_shift_norm,
        )
        transformer = partial(
            self.transformers[transformer_type],
            context_dim=context_dim,
            use_linear=use_linear_in_transformer,
            use_checkpoint=use_checkpoint,
            disable_self_attn=disable_self_attentions,
            disable_cross_attn=disable_cross_attentions,
        )

        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
            elif self.num_classes == "continuous":
                print("setting up linear c_adm embedding layer")
                self.label_emb = nn.Linear(1, time_embed_dim)
            elif self.num_classes == "sequential":
                assert adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
                        linear(adm_in_channels, time_embed_dim),
                        nn.SiLU(),
                        linear(time_embed_dim, time_embed_dim),
                    )
                )
            else:
                raise ValueError()

        # ---- Encoder ----
        self.input_blocks = nn.ModuleList([
            TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))
        ])
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1  # current downsampling factor
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [resblock(ch, out_channels=mult * model_channels)]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels > -1:
                        current_num_heads = ch // num_head_channels
                        current_head_dim = num_head_channels
                    else:
                        current_num_heads = num_heads
                        current_head_dim = ch // num_heads

                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                        layers.append(
                            SelfTransformerBlock(ch, current_head_dim)
                            if not use_spatial_transformer
                            else transformer(
                                ch, current_num_heads, current_head_dim, depth=transformer_depth[level]
                            )
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(TimestepEmbedSequential(
                    resblock(ch, out_channels=out_ch, down=True) if resblock_updown
                    else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                ))
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2

        # ---- Bottleneck ----
        if num_head_channels > -1:
            current_num_heads = ch // num_head_channels
            current_head_dim = num_head_channels
        else:
            current_num_heads = num_heads
            current_head_dim = ch // num_heads
        self.middle_block = TimestepEmbedSequential(
            resblock(ch),
            SelfTransformerBlock(ch, current_head_dim) if not use_spatial_transformer
            else transformer(ch, current_num_heads, current_head_dim, depth=transformer_depth_middle),
            resblock(ch),
        )

        # ---- Decoder (+ optional per-level auxiliary modules) ----
        self.output_blocks = nn.ModuleList([])
        self.map_modules = nn.ModuleList([])
        self.warp_modules = nn.ModuleList([])
        self.style_modules = nn.ModuleList([])

        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(self.num_res_blocks[level] + 1):
                ich = input_block_chans.pop()
                layers = [resblock(ch + ich, out_channels=model_channels * mult)]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    if num_head_channels > -1:
                        current_num_heads = ch // num_head_channels
                        current_head_dim = num_head_channels
                    else:
                        current_num_heads = num_heads
                        current_head_dim = ch // num_heads

                    if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
                        layers.append(
                            SelfTransformerBlock(ch, current_head_dim) if not use_spatial_transformer
                            else transformer(
                                ch, current_num_heads, current_head_dim, depth=transformer_depth[level]
                            )
                        )
                if level and i == self.num_res_blocks[level]:
                    out_ch = ch
                    layers.append(
                        resblock(ch, up=True) if resblock_updown else Upsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))

            # One auxiliary module per resolution level, sized by that level's
            # last skip width `ich`.
            if map_module:
                self.map_modules.append(
                    SelfTransformerBlock(ich)
                )

            if warp_module:
                self.warp_modules.append(
                    SelfTransformerBlock(ich)
                )

            if style_modulation:
                self.style_modules.append(nn.ModuleList([
                    nn.LayerNorm(ch*2),
                    nn.Linear(time_embed_dim, ch*2),
                    zero_module(nn.Linear(ch*2, ch*2))
                ]))

        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
        )
+
    def forward(self, x, timesteps=None, y=None, *args, **kwargs):
        """Training forward: timestep (+ optional label) embedding, UNet body, output head.

        Note: the head runs in model dtype and the result is cast to `x.dtype`
        afterwards (InferenceForward casts before the head instead).
        """
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
        emb = self.time_embed(t_emb)
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y.to(self.dtype))

        h = self._forward(x, emb, *args, **kwargs)
        return self.out(h).to(x.dtype)
+
    def _forward(self, x, emb, control=None, context=None, mask=None, **additional_context):
        """Encoder / bottleneck / decoder pass with skip connections.

        `control` is accepted for subclass signature compatibility but unused
        in this base implementation.
        """
        hs = []
        h = x.to(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, context, mask, **additional_context)
            hs.append(h)

        h = self.middle_block(h, emb, context, mask, **additional_context)

        # Skip connections are consumed in LIFO order via channel concat.
        for module in self.output_blocks:
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context, mask, **additional_context)
        return h
+
+
class DualCondUNet(UNetModel):
    """UNet whose encoder receives additive hint features at fixed block indices."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Input-block indices after which one hint tensor from `control` is consumed.
        self.hint_encoder_index = [0, 3, 6, 9, 11]

    def _forward(self, x, emb, control=None, context=None, mask=None, **additional_context):
        """UNet pass; `control` must yield one tensor per hint index.

        Raises:
            ValueError: if `control` is None (previously this surfaced as an
                opaque ``TypeError: 'NoneType' object is not iterable``).
        """
        h = x.to(self.dtype)
        hs = []

        if not exists(control):
            raise ValueError("DualCondUNet requires `control` hint features")
        control_iter = iter(control)
        for idx, module in enumerate(self.input_blocks):
            h = module(h, emb, context, mask, **additional_context)

            if idx in self.hint_encoder_index:
                h += next(control_iter)
            hs.append(h)

        h = self.middle_block(h, emb, context, mask, **additional_context)

        for idx, module in enumerate(self.output_blocks):
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context, mask, **additional_context)

        return h
+
class OldUnet(UNetModel):
    """Legacy UNet variant with a ControlNet-style semantic hint encoder."""

    def __init__(self, c_channels, model_channels, channel_mult, *args, **kwargs):
        super().__init__(channel_mult=channel_mult, model_channels=model_channels, *args, **kwargs)
        """
        Semantic condition input blocks, implementation from ControlNet.
        Paper: Adding Conditional Control to Text-to-Image Diffusion Models
        Authors: Lvmin Zhang, Anyi Rao, and Maneesh Agrawala
        Code link: https://github.com/lllyasviel/ControlNet
        """
        from refnet.modules.encoder import SimpleEncoder, MultiEncoder
        # self.semantic_input_blocks = SimpleEncoder(c_channels, model_channels)
        self.semantic_input_blocks = MultiEncoder(c_channels, model_channels, channel_mult)
        # Input-block indices after which one semantic hint is consumed.
        self.hint_encoder_index = [0, 3, 6, 9, 11]

    def forward(self, x, timesteps=None, control=None, context=None, y=None, **kwargs):
        """Full pass: encode `control[0]` into hints, add them into the encoder
        at the configured indices, then run the usual UNet body and head.
        """
        # First control entry is the semantic conditioning image/latent.
        concat = control[0].to(self.dtype)
        context = context.to(self.dtype)

        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb).to(self.dtype)

        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

        h = x.to(self.dtype)
        hints = self.semantic_input_blocks(concat, emb, context)

        for idx, module in enumerate(self.input_blocks):
            h = module(h, emb, context)
            # Hints are consumed in order, front to back.
            if idx in self.hint_encoder_index:
                h += hints.pop(0)

            hs.append(h)

        h = self.middle_block(h, emb, context)

        for module in self.output_blocks:
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
        h = h.to(x.dtype)
        return self.out(h)
+
+
+class UNetEncoder(nn.Module):
+    """
+    Encoder-only UNet that mirrors the downsampling half of UNetModel and
+    emits one zero-initialized projection ("hint") per input block.
+
+    With style_modulation=True each hint is a channel vector (spatially
+    pooled, linearly projected); otherwise it is a 1x1-conv feature map
+    flattened to (b, h*w, c). Hints are returned deepest-first.
+    """
+    # Registry of supported attention transformer implementations.
+    transformers = {
+        "vanilla": SpatialTransformer,
+    }
+
+    def __init__(
+        self,
+        in_channels,
+        model_channels,
+        num_res_blocks,
+        attention_resolutions,
+        dropout = 0,
+        channel_mult = (1, 2, 4, 8),
+        conv_resample = True,
+        dims = 2,
+        num_classes = None,
+        use_checkpoint = False,
+        num_heads = -1,
+        num_head_channels = -1,
+        use_scale_shift_norm = False,
+        resblock_updown = False,
+        use_spatial_transformer = False, # custom transformer support
+        transformer_depth = 1, # custom transformer support
+        context_dim = None, # custom transformer support
+        disable_self_attentions = None,
+        disable_cross_attentions = None,
+        num_attention_blocks = None,
+        use_linear_in_transformer = False,
+        adm_in_channels = None,
+        transformer_type = "vanilla",
+        style_modulation = False,
+    ):
+        super().__init__()
+        if use_spatial_transformer:
+            assert exists(
+                context_dim) or disable_cross_attentions, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
+            assert transformer_type in self.transformers.keys(), f'Assigned transformer is not implemented.. Choices: {self.transformers.keys()}'
+            from omegaconf.listconfig import ListConfig
+            if type(context_dim) == ListConfig:
+                context_dim = list(context_dim)
+
+
+        if num_heads == -1:
+            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+
+        if num_head_channels == -1:
+            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+        self.in_channels = in_channels
+        self.model_channels = model_channels
+        # num_res_blocks may be a global int or a per-level list.
+        if isinstance(num_res_blocks, int):
+            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+        else:
+            if len(num_res_blocks) != len(channel_mult):
+                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
+                                 "as a list/tuple (per-level) with the same length as channel_mult")
+            self.num_res_blocks = num_res_blocks
+        if disable_self_attentions is not None:
+            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
+            assert len(disable_self_attentions) == len(channel_mult)
+        if num_attention_blocks is not None:
+            assert len(num_attention_blocks) == len(self.num_res_blocks)
+            assert all(
+                map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
+            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+                  f"attention will still not be set.")
+
+        self.attention_resolutions = attention_resolutions
+        self.dropout = dropout
+        self.channel_mult = channel_mult
+        self.conv_resample = conv_resample
+        self.num_classes = num_classes
+        self.use_checkpoint = use_checkpoint
+        self.dtype = torch.float32
+        self.num_heads = num_heads
+        self.num_head_channels = num_head_channels
+        self.style_modulation = style_modulation
+
+        if isinstance(transformer_depth, int):
+            transformer_depth = len(channel_mult) * [transformer_depth]
+
+        time_embed_dim = model_channels * 4
+
+        # Partial constructors so the level loop below stays readable.
+        resblock = partial(
+            ResBlock,
+            emb_channels=time_embed_dim,
+            dropout=dropout,
+            dims=dims,
+            use_checkpoint=use_checkpoint,
+            use_scale_shift_norm=use_scale_shift_norm,
+        )
+
+        transformer = partial(
+            self.transformers[transformer_type],
+            context_dim=context_dim,
+            use_linear=use_linear_in_transformer,
+            use_checkpoint=use_checkpoint,
+            disable_self_attn=disable_self_attentions,
+            disable_cross_attn=disable_cross_attentions,
+        )
+
+        zero_conv = partial(nn.Conv2d, kernel_size=1, stride=1, padding=0)
+
+        self.time_embed = nn.Sequential(
+            linear(model_channels, time_embed_dim),
+            nn.SiLU(),
+            linear(time_embed_dim, time_embed_dim),
+        )
+
+        if self.num_classes is not None:
+            if isinstance(self.num_classes, int):
+                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+            elif self.num_classes == "continuous":
+                print("setting up linear c_adm embedding layer")
+                self.label_emb = nn.Linear(1, time_embed_dim)
+            elif self.num_classes == "sequential":
+                assert adm_in_channels is not None
+                self.label_emb = nn.Sequential(
+                    nn.Sequential(
+                        linear(adm_in_channels, time_embed_dim),
+                        nn.SiLU(),
+                        linear(time_embed_dim, time_embed_dim),
+                    )
+                )
+            else:
+                raise ValueError()
+
+        self.input_blocks = nn.ModuleList(
+            [
+                TimestepEmbedSequential(
+                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
+                )
+            ]
+        )
+        # zero_layers stays index-aligned with input_blocks: one zero-init
+        # projection per block (Linear for style modulation, 1x1 conv otherwise).
+        self.zero_layers = nn.ModuleList([zero_module(
+            nn.Linear(model_channels, model_channels * 2) if style_modulation else
+            zero_conv(model_channels, model_channels)
+        )])
+
+        ch = model_channels
+        ds = 1  # current downsampling factor, doubled after each Downsample
+        for level, mult in enumerate(channel_mult):
+            for nr in range(self.num_res_blocks[level]):
+                layers = [resblock(ch, out_channels=mult * model_channels)]
+                ch = mult * model_channels
+                if ds in attention_resolutions:
+                    # NOTE(review): assumes num_head_channels != -1 on this
+                    # path; -1 would make num_heads negative — confirm configs.
+                    num_heads = ch // num_head_channels
+
+                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
+                        layers.append(
+                            SelfTransformerBlock(ch, num_head_channels)
+                            if not use_spatial_transformer
+                            else transformer(
+                                ch, num_heads, num_head_channels, depth=transformer_depth[level]
+                            )
+                        )
+                self.input_blocks.append(TimestepEmbedSequential(*layers))
+                self.zero_layers.append(zero_module(
+                    nn.Linear(ch, ch * 2) if style_modulation else zero_conv(ch, ch)
+                ))
+
+            if level != len(channel_mult) - 1:
+                out_ch = ch
+                self.input_blocks.append(TimestepEmbedSequential(
+                    resblock(ch, out_channels=mult * model_channels, down=True) if resblock_updown else Downsample(
+                        ch, conv_resample, dims=dims, out_channels=out_ch
+                    )
+                ))
+                self.zero_layers.append(zero_module(
+                    nn.Linear(out_ch, min(model_channels * 8, out_ch * 4)) if style_modulation else
+                    zero_conv(out_ch, out_ch)
+                ))
+                ch = out_ch
+                ds *= 2
+
+    def forward(self, x, timesteps = None, y = None, *args, **kwargs):
+        """Embed the timestep (and optional class label), then run the encoder."""
+        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
+        emb = self.time_embed(t_emb)
+
+        assert (y is not None) == (
+            self.num_classes is not None
+        ), "must specify y if and only if the model is class-conditional"
+        if self.num_classes is not None:
+            assert y.shape[0] == x.shape[0]
+            emb = emb + self.label_emb(y.to(self.dtype))
+
+        hs = self._forward(x, emb, *args, **kwargs)
+        return hs
+
+    def _forward(self, x, emb, context = None, **additional_context):
+        """Run every input block and return its zero-projected hint, deepest first."""
+        hints = []
+        h = x.to(self.dtype)
+
+        for zero_layer, module in zip(self.zero_layers, self.input_blocks):
+            h = module(h, emb, context, **additional_context)
+
+            if self.style_modulation:
+                # Spatially pooled channel statistics -> modulation vector.
+                hint = zero_layer(h.mean(dim=[2, 3]))
+            else:
+                hint = zero_layer(h)
+                hint = rearrange(hint, "b c h w -> b (h w) c").contiguous()
+            hints.append(hint)
+
+        hints.reverse()
+        return hints
\ No newline at end of file
diff --git a/refnet/sampling/__init__.py b/refnet/sampling/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c00b27a8059ac7a481abae9709eaebde64fd73d
--- /dev/null
+++ b/refnet/sampling/__init__.py
@@ -0,0 +1,11 @@
+from .denoiser import CFGDenoiser, DiffuserDenoiser
+from .hook import UnetHook, torch_dfs
+from .tps_transformation import tps_warp
+from .sampler import KDiffusionSampler, kdiffusion_sampler_list
+from .scheduler import get_noise_schedulers
+
+def get_sampler_list():
+ sampler_list = [
+ "diffuser_" + k for k in DiffuserDenoiser.scheduler_types.keys()
+ ] + kdiffusion_sampler_list()
+ return sorted(sampler_list)
\ No newline at end of file
diff --git a/refnet/sampling/denoiser.py b/refnet/sampling/denoiser.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d7604772c7486e5a22e5ab326d3da6fb2bb0204
--- /dev/null
+++ b/refnet/sampling/denoiser.py
@@ -0,0 +1,181 @@
+import torch
+import torch.nn as nn
+
+import inspect
+import os.path as osp
+from typing import Union, Optional
+from tqdm import tqdm
+from omegaconf import OmegaConf
+from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
+from diffusers.schedulers import (
+ DDIMScheduler,
+ DPMSolverMultistepScheduler,
+ PNDMScheduler,
+ LMSDiscreteScheduler,
+)
+
+def exists(v):
+ return v is not None
+
+
+
+class CFGDenoiser(nn.Module):
+ """
+ Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
+ that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
+ instead of one. Originally, the second prompt is just an empty string, but we use non-empty
+ negative prompt.
+ """
+
+ def __init__(self, model, device):
+ super().__init__()
+ denoiser = CompVisDenoiser if model.parameterization == "eps" else CompVisVDenoiser
+ self.model_wrap = denoiser(model, device=device)
+
+ @property
+ def inner_model(self):
+ return self.model_wrap
+
+ def forward(
+ self,
+ x,
+ sigma,
+ cond: dict,
+ cond_scale: Union[float, list[float]]
+ ):
+ """
+ Simplify k-diffusion sampler for sketch colorizaiton.
+ Available for reference CFG / sketch CFG or Dual CFG
+ """
+ if not isinstance(cond_scale, list):
+ if cond_scale > 1.:
+ repeats = 2
+ else:
+ return self.inner_model(x, sigma, cond=cond)
+ else:
+ repeats = 3
+
+ x_in = torch.cat([x] * repeats)
+ sigma_in = torch.cat([sigma] * repeats)
+ x_out = self.inner_model(x_in, sigma_in, cond=cond).chunk(repeats)
+
+ if repeats == 2:
+ x_cond, x_uncond = x_out[:]
+ return x_uncond + (x_cond - x_uncond) * cond_scale
+ else:
+ x_cond, x_uncond_0, x_uncond_1 = x_out[:]
+ return (x_uncond_0 + (x_cond - x_uncond_0) * cond_scale[0] +
+ x_uncond_1 + (x_cond - x_uncond_1) * cond_scale[1]) * 0.5
+
+
+
+
+scheduler_config_path = "configs/scheduler_cfgs"
+class DiffuserDenoiser:
+ scheduler_types = {
+ "ddim": DDIMScheduler,
+ "dpm": DPMSolverMultistepScheduler,
+ "dpm_sde": DPMSolverMultistepScheduler,
+ "pndm": PNDMScheduler,
+ "lms": LMSDiscreteScheduler
+ }
+ def __init__(self, scheduler_type, prediction_type, use_karras=False):
+ scheduler_type = scheduler_type.replace("diffuser_", "")
+ assert scheduler_type in self.scheduler_types.keys(), "Selected scheduler is not implemented"
+ scheduler = self.scheduler_types[scheduler_type]
+ scheduler_config = OmegaConf.load(osp.abspath(osp.join(scheduler_config_path, scheduler_type + ".yaml")))
+ if "use_karras_sigmas" in set(inspect.signature(scheduler).parameters.keys()):
+ scheduler_config.use_karras_sigmas = use_karras
+ self.scheduler = scheduler(prediction_type=prediction_type, **scheduler_config)
+
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(
+ inspect.signature(self.scheduler.step).parameters.keys()
+ )
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(
+ inspect.signature(self.scheduler.step).parameters.keys()
+ )
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def __call__(
+ self,
+ x,
+ cond,
+ cond_scale,
+ unet,
+ timesteps,
+ generator: Optional[Union[torch.Generator, list[torch.Generator]]] = None,
+ eta: float = 0.0,
+ device: str = "cuda"
+ ):
+ self.scheduler.set_timesteps(timesteps, device=device)
+ timesteps = self.scheduler.timesteps
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ x_start = x
+ x = x * self.scheduler.init_noise_sigma
+ inpaint_latents = cond.pop("inpaint_bg", None)
+
+ if exists(inpaint_latents):
+ mask = cond.get("mask", None)
+ threshold = cond.pop("threshold", 0.5)
+ inpaint_latents = inpaint_latents[0]
+ assert exists(mask)
+ mask = mask[0]
+ mask = torch.where(mask > threshold, torch.ones_like(mask), torch.zeros_like(mask))
+
+ for i, t in enumerate(tqdm(timesteps)):
+ x_t = self.scheduler.scale_model_input(x, t)
+
+ if not isinstance(cond_scale, list):
+ if cond_scale > 1.:
+ repeats = 2
+ else:
+ repeats = 1
+ else:
+ repeats = 3
+
+ x_in = torch.cat([x_t] * repeats)
+ x_out = unet.apply_model(
+ x_in,
+ t[None].expand(x_in.shape[0]),
+ cond=cond
+ )
+
+ if repeats == 1:
+ pred = x_out
+
+ elif repeats == 2:
+ x_cond, x_uncond = x_out.chunk(2)
+ pred = x_uncond + (x_cond - x_uncond) * cond_scale
+
+ else:
+ x_cond, x_uncond_0, x_uncond_1 = x_out.chunk(3)
+ pred = (x_uncond_0 + (x_cond - x_uncond_0) * cond_scale[0] +
+ x_uncond_1 + (x_cond - x_uncond_1) * cond_scale[1]) * 0.5
+
+ x = self.scheduler.step(
+ pred, t, x, **extra_step_kwargs, return_dict=False
+ )[0]
+
+ if exists(inpaint_latents) and exists(mask) and i < len(timesteps) - 1:
+ noise_timestep = timesteps[i + 1]
+ init_latents_proper = inpaint_latents
+ init_latents_proper = self.scheduler.add_noise(
+ init_latents_proper, x_start, torch.tensor([noise_timestep])
+ )
+ x = (1 - mask) * init_latents_proper + mask * x
+
+ return x
\ No newline at end of file
diff --git a/refnet/sampling/hook.py b/refnet/sampling/hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..940240c7973a6e3e955e34fd0a591e5cea8c2995
--- /dev/null
+++ b/refnet/sampling/hook.py
@@ -0,0 +1,257 @@
+import torch
+import torch.nn as nn
+
+from refnet.modules.transformer import BasicTransformerBlock, SelfInjectedTransformerBlock
+from refnet.util import checkpoint_wrapper
+
+"""
+ This implementation refers to Multi-ControlNet, thanks for the authors
+ Paper: Adding Conditional Control to Text-to-Image Diffusion Models
+ Link: https://github.com/Mikubill/sd-webui-controlnet
+"""
+
+def exists(v):
+    # None is the sole sentinel for "missing" throughout this module.
+    return v is not None
+
+def torch_dfs(model: nn.Module):
+ result = [model]
+ for child in model.children():
+ result += torch_dfs(child)
+ return result
+
+class AutoMachine():
+    # Two-state flag used by UnetHook: "write" caches reference self-attention
+    # features into each block's bank; "read" consumes them during denoising.
+    Read = "read"
+    Write = "write"
+
+
+"""
+ This class controls the attentions of reference unet and denoising unet
+"""
+class ReferenceAttentionControl:
+ writer_modules = []
+ reader_modules = []
+ def __init__(
+ self,
+ reader_module,
+ writer_module,
+ time_embed_ch = 0,
+ only_decoder = True,
+ *args,
+ **kwargs
+ ):
+ self.time_embed_ch = time_embed_ch
+ self.trainable_layers = []
+ self.only_decoder = only_decoder
+ self.hooked = False
+
+ self.register("read", reader_module)
+ self.register("write", writer_module)
+
+ if time_embed_ch > 0:
+ self.insert_time_emb_proj(reader_module)
+
+ def insert_time_emb_proj(self, unet):
+ for module in torch_dfs(unet.output_blocks if self.only_decoder else unet):
+ if isinstance(module, BasicTransformerBlock):
+ module.time_proj = nn.Linear(self.time_embed_ch, module.dim)
+ self.trainable_layers.append(module.time_proj)
+
+ def register(self, mode, unet):
+ @checkpoint_wrapper
+ def transformer_forward_write(self, x, context=None, mask=None, emb=None, **kwargs):
+ x_in = self.norm1(x)
+ x = self.attn1(x_in) + x
+
+ if not self.disable_cross_attn:
+ x = self.attn2(self.norm2(x), context) + x
+ x = self.ff(self.norm3(x)) + x
+
+ self.bank = x_in
+ return x
+
+ @checkpoint_wrapper
+ def transformer_forward_read(self, x, context=None, mask=None, emb=None, **kwargs):
+ if exists(self.bank):
+ bank = self.bank
+ if bank.shape[0] != x.shape[0]:
+ bank = bank.repeat(x.shape[0], 1, 1)
+ if hasattr(self, "time_proj"):
+ bank = bank + self.time_proj(emb).unsqueeze(1)
+ x_in = self.norm1(x)
+
+ x = self.attn1(
+ x = x_in,
+ context = torch.cat([x_in, bank], 1),
+ mask = mask,
+ scale_factor = self.scale_factor,
+ **kwargs
+ ) + x
+
+ x = self.attn2(
+ x = self.norm2(x),
+ context = context,
+ mask = mask,
+ scale = self.reference_scale,
+ scale_factor = self.scale_factor
+ ) + x
+
+ x = self.ff(self.norm3(x)) + x
+ else:
+ x = self.original_forward(x, context, mask, emb)
+ return x
+
+ assert mode in ["write", "read"]
+
+ if mode == "read":
+ self.hooked = True
+ for module in torch_dfs(unet.output_blocks if self.only_decoder else unet):
+ if isinstance(module, BasicTransformerBlock):
+ if mode == "write":
+ module.original_forward = module.forward
+ module.forward = transformer_forward_write.__get__(module, BasicTransformerBlock)
+ self.writer_modules.append(module)
+ else:
+ if not isinstance(module, SelfInjectedTransformerBlock):
+ print(f"Hooking transformer block {module.__class__.__name__} for read mode")
+ module.original_forward = module.forward
+ module.forward = transformer_forward_read.__get__(module, BasicTransformerBlock)
+ self.reader_modules.append(module)
+
+ def update(self):
+ for idx in range(len(self.writer_modules)):
+ self.reader_modules[idx].bank = self.writer_modules[idx].bank
+
+ def restore(self):
+ for idx in range(len(self.writer_modules)):
+ self.writer_modules[idx].forward = self.writer_modules[idx].original_forward
+ self.reader_modules[idx].forward = self.reader_modules[idx].original_forward
+ self.reader_modules[idx].bank = None
+ self.hooked = False
+
+ def clean(self):
+ for idx in range(len(self.reader_modules)):
+ self.reader_modules[idx].bank = None
+ for idx in range(len(self.writer_modules)):
+ self.writer_modules[idx].bank = None
+ self.hooked = False
+
+ def reader_restore(self):
+ for idx in range(len(self.reader_modules)):
+ self.reader_modules[idx].forward = self.reader_modules[idx].original_forward
+ self.reader_modules[idx].bank = None
+ self.hooked = False
+
+ def get_trainable_layers(self):
+ return self.trainable_layers
+
+
+"""
+ This class is for self-injection inside the denoising unet
+"""
+class UnetHook:
+    """
+    Self-injection hook: patches the denoising UNet so that, on each step, a
+    noised reference latent is first run in "write" mode (caching self-attention
+    contexts), then the actual latent is denoised in "read" mode attending over
+    those caches. Mirrors the reference-only trick in Multi-ControlNet.
+    """
+    def __init__(self):
+        super().__init__()
+        self.attention_auto_machine = AutoMachine.Read
+
+    def enhance_reference(
+        self,
+        model,
+        ldm,
+        bs,
+        s,
+        r,
+        style_cfg=0.5,
+        control_cfg=0,
+        gr_indice=None,
+        injection=False,
+        start_step=0,
+    ):
+        """
+        Install the hooks. `s` is the sketch control, `r` the reference latent,
+        `bs` the batch size; `style_cfg` blends style fidelity, `start_step`
+        (fraction of the schedule) delays when injection kicks in.
+        """
+        def forward(self, x, t, control, context, **kwargs):
+            # Only inject once enough of the schedule has elapsed.
+            if 1 - t[0] / (ldm.num_timesteps - 1) >= outer.start_step:
+                # Write
+                outer.attention_auto_machine = AutoMachine.Write
+
+                # Noise the reference latent to the current timestep, then run it
+                # through the UNet purely to populate the attention banks.
+                rx = ldm.add_noise(outer.r.cpu(), torch.round(t.float()).long().cpu()).cuda().to(x.dtype)
+                self.original_forward(rx, t, control=outer.s, context=context, **kwargs)
+
+            # Read
+            outer.attention_auto_machine = AutoMachine.Read
+            return self.original_forward(x, t, control=control, context=context, **kwargs)
+
+        def hacked_basic_transformer_inner_forward(self, x, context=None, mask=None, emb=None, **kwargs):
+            x_norm1 = self.norm1(x)
+            self_attn1 = None
+            if self.disable_self_attn:
+                # Do not use self-attention
+                self_attn1 = self.attn1(x_norm1, context=context, **kwargs)
+
+            else:
+                # Use self-attention
+                self_attention_context = x_norm1
+                if outer.attention_auto_machine == AutoMachine.Write:
+                    self.bank.append(self_attention_context.detach().clone())
+                    self.style_cfgs.append(outer.current_style_fidelity)
+                if outer.attention_auto_machine == AutoMachine.Read:
+                    if len(self.bank) > 0:
+                        style_cfg = sum(self.style_cfgs) / float(len(self.style_cfgs))
+                        # Unconditional pass attends over [own, reference] tokens.
+                        self_attn1_uc = self.attn1(
+                            x_norm1,
+                            context=torch.cat([self_attention_context] + self.bank, dim=1),
+                            **kwargs
+                        )
+                        self_attn1_c = self_attn1_uc.clone()
+                        # For the listed (guidance-free) batch indices, redo
+                        # attention without the reference for style fidelity.
+                        if len(outer.current_uc_indices) > 0 and style_cfg > 1e-5:
+                            self_attn1_c[outer.current_uc_indices] = self.attn1(
+                                x_norm1[outer.current_uc_indices],
+                                context=self_attention_context[outer.current_uc_indices],
+                                **kwargs
+                            )
+                        self_attn1 = style_cfg * self_attn1_c + (1.0 - style_cfg) * self_attn1_uc
+                    self.bank = []
+                    self.style_cfgs = []
+                if self_attn1 is None:
+                    self_attn1 = self.attn1(x_norm1, context=self_attention_context)
+
+            x = self_attn1.to(x.dtype) + x
+            x = self.attn2(self.norm2(x), context, mask, self.reference_scale, self.scale_factor, **kwargs) + x
+            x = self.ff(self.norm3(x)) + x
+            return x
+
+        # Pre-encode the sketch control for the write pass, scaled by control_cfg.
+        self.s = [s.repeat(bs, 1, 1, 1) * control_cfg for s in ldm.control_encoder(s)]
+        self.r = r
+        # NOTE(review): self.injection is stored but not read in this class —
+        # presumably consumed elsewhere; verify.
+        self.injection = injection
+        self.start_step = start_step
+        self.current_uc_indices = gr_indice
+        self.current_style_fidelity = style_cfg
+
+        outer = self
+        model = model.diffusion_model
+        model.original_forward = model.forward
+        # TODO: change the class name to target
+        model.forward = forward.__get__(model, model.__class__)
+        all_modules = torch_dfs(model)
+
+        for module in all_modules:
+            if isinstance(module, BasicTransformerBlock):
+                module._unet_hook_original_forward = module.forward
+                module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
+                module.bank = []
+                module.style_cfgs = []
+
+
+    def restore(self, model):
+        """Undo every patch installed by enhance_reference."""
+        model = model.diffusion_model
+        if hasattr(model, "original_forward"):
+            model.forward = model.original_forward
+            del model.original_forward
+
+        all_modules = torch_dfs(model)
+        for module in all_modules:
+            if isinstance(module, BasicTransformerBlock):
+                if hasattr(module, "_unet_hook_original_forward"):
+                    module.forward = module._unet_hook_original_forward
+                    del module._unet_hook_original_forward
+                if hasattr(module, "bank"):
+                    module.bank = None
+                if hasattr(module, "style_cfgs"):
+                    del module.style_cfgs
\ No newline at end of file
diff --git a/refnet/sampling/manipulation.py b/refnet/sampling/manipulation.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e1b64e1bf7bc16da54a313fd615cb92dc765c18
--- /dev/null
+++ b/refnet/sampling/manipulation.py
@@ -0,0 +1,135 @@
+import torch
+import torch.nn.functional as F
+
+import numpy as np
+
+
+def compute_pwv(s: torch.Tensor, dscale: torch.Tensor, ratio=2, thresholds=[0.5, 0.55, 0.65, 0.95]):
+ """
+ The shape of input scales tensor should be (b, n, 1)
+ """
+ assert len(s.shape) == 3, len(thresholds) == 4
+ maxm = s.max(dim=1, keepdim=True).values
+ minm = s.min(dim=1, keepdim=True).values
+ d = maxm - minm
+
+ maxmin = (s - minm) / d
+
+ adjust_scale = torch.where(maxmin <= thresholds[0],
+ -dscale * ratio,
+ -dscale + dscale * (maxmin - thresholds[0]) / (thresholds[1] - thresholds[0]))
+ adjust_scale = torch.where(maxmin > thresholds[1],
+ 0.5 * dscale * (maxmin - thresholds[1]) / (thresholds[2] - thresholds[1]),
+ adjust_scale)
+ adjust_scale = torch.where(maxmin > thresholds[2],
+ 0.5 * dscale + 0.5 * dscale * (maxmin - thresholds[2]) / (thresholds[3] - thresholds[2]),
+ adjust_scale)
+ adjust_scale = torch.where(maxmin > thresholds[3], dscale, adjust_scale)
+ return adjust_scale
+
+
+def local_manipulate_step(clip, v, t, target_scale, a=None, c=None, enhance=False, thresholds=[]):
+    """
+    Apply one text-driven local manipulation to CLIP visual tokens.
+
+    v: (b, n+1, c) tokens with the class token at index 0; t and c are
+    already-encoded target/control text embeddings; a is the *raw* anchor
+    prompt string (encoded here when given). Returns the manipulated tokens
+    with the class token re-attached.
+    """
+    # print(f"target:{t}, anchor:{a}")
+    cls_token = v[:, 0].unsqueeze(1)
+    v = v[:, 1:]
+
+    cur_target_scale = clip.calculate_scale(cls_token, t)
+    # control_scale = clip.calculate_scale(cls_token, c)
+    # print(f"current global target scale: {cur_target_scale},",
+    #       f" global control scale: {control_scale}")
+
+    if a is not None and a != "none":
+        a = [a] * v.shape[0]
+        a = clip.encode_text(a)
+        anchor_scale = clip.calculate_scale(cls_token, a)
+        # Shift needed to reach target_scale from the current (or anchor) scale.
+        dscale = target_scale - cur_target_scale if not enhance else target_scale - anchor_scale
+        # print(f"global anchor scale: {anchor_scale}")
+
+        c_map = clip.calculate_scale(v, c)
+        a_map = clip.calculate_scale(v, a)
+        # NOTE(review): `c` is an encoded embedding tensor here, so
+        # `c != "everything"` presumably always takes the compute_pwv branch —
+        # confirm the intended "everything" bypass against callers.
+        pwm = compute_pwv(c_map, dscale, thresholds=thresholds) if c != "everything" else dscale
+        base = 1 if enhance else 0
+        v = v + (pwm + base * a_map) * (t - a)
+    else:
+        dscale = target_scale - cur_target_scale
+        c_map = clip.calculate_scale(v, c)
+        pwm = compute_pwv(c_map, dscale, thresholds=thresholds) if c != "everything" else dscale
+        v = v + pwm * t
+    # Re-attach the untouched class token.
+    v = torch.cat([cls_token, v], dim=1)
+    return v
+
+def local_manipulate(clip, v, targets, target_scales, anchors, controls, enhances=[], thresholds_list=[]):
+    """
+    v: visual tokens in shape (b, n, c)
+    target: target text embeddings in shape (b, 1 ,c)
+    control: control text embeddings in shape (b, 1, c)
+
+    NOTE(review): zip() truncates to the shortest sequence, so with the
+    default empty `enhances`/`thresholds_list` no manipulation step runs at
+    all — confirm callers always pass lists matching len(targets).
+    """
+    # Encode controls and targets in one batch, then split them back apart.
+    controls, targets = clip.encode_text(controls + targets).chunk(2)
+    for t, a, c, s_t, enhance, thresholds in zip(targets, anchors, controls, target_scales, enhances, thresholds_list):
+        v = local_manipulate_step(clip, v, t, s_t, a, c, enhance, thresholds)
+    return v
+
+
+def global_manipulate_step(clip, v, t, target_scale, a=None, enhance=False):
+ if a is not None and a != "none":
+ a = [a] * v.shape[0]
+ a = clip.encode_text(a)
+ if enhance:
+ s_a = clip.calculate_scale(v, a)
+ v = v - s_a * a
+ else:
+ v = v + target_scale * (t - a)
+ return v
+ if enhance:
+ v = v + target_scale * t
+ else:
+ cur_target_scale = clip.calculate_scale(v, t)
+ v = v + (target_scale - cur_target_scale) * t
+ return v
+
+
+def global_manipulate(clip, v, targets, target_scales, anchors, enhances):
+ targets = clip.encode_text(targets)
+ for t, a, s_t, enhance in zip(targets, anchors, target_scales, enhances):
+ v = global_manipulate_step(clip, v, t, s_t, a, enhance)
+ return v
+
+
+def assign_heatmap(s: torch.Tensor, threshold: float):
+ """
+ The shape of input scales tensor should be (b, n, 1)
+ """
+ maxm = s.max(dim=1, keepdim=True).values
+ minm = s.min(dim=1, keepdim=True).values
+ d = maxm - minm
+ return torch.where((s - minm) / d < threshold, torch.zeros_like(s), torch.ones_like(s) * 0.25)
+
+
+def get_heatmaps(model, reference, height, width, vis_c, ts0, ts1, ts2, ts3,
+                 controls, targets, anchors, thresholds_list, target_scales, enhances):
+    """
+    Visualize where a visualization prompt `vis_c` activates on the (optionally
+    manipulated) reference image tokens, at four thresholds ts0..ts3.
+    Returns four uint8 heatmap arrays of shape (height, width, 1).
+    """
+    model.low_vram_shift("cond")
+    clip = model.cond_stage_model
+
+    v = clip.encode(reference, "full")
+    if len(targets) > 0:
+        controls, targets = clip.encode_text(controls + targets).chunk(2)
+        inputs_iter = zip(controls, targets, anchors, target_scales, thresholds_list, enhances)
+        for c, t, a, target_scale, thresholds, enhance in inputs_iter:
+            # update image tokens
+            v = local_manipulate_step(clip, v, t, target_scale, a, c, enhance, thresholds)
+    # Tokens minus the class token form a square grid of patches.
+    token_length = v.shape[1] - 1
+    grid_num = int(token_length ** 0.5)
+    vis_c = clip.encode_text([vis_c])
+    local_v = v[:, 1:]
+    scale = clip.calculate_scale(local_v, vis_c)
+    # Reshape per-patch scores to the grid and upsample to image resolution.
+    scale = scale.permute(0, 2, 1).view(1, 1, grid_num, grid_num)
+    scale = F.interpolate(scale, size=(height, width), mode="bicubic").squeeze(0).view(1, height * width)
+
+    # calculate heatmaps
+    heatmaps = []
+    for threshold in [ts0, ts1, ts2, ts3]:
+        heatmap = assign_heatmap(scale, threshold=threshold)
+        heatmap = heatmap.view(1, height, width).permute(1, 2, 0).cpu().numpy()
+        heatmap = (heatmap * 255.).astype(np.uint8)
+        heatmaps.append(heatmap)
+    return heatmaps
\ No newline at end of file
diff --git a/refnet/sampling/sampler.py b/refnet/sampling/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc3305c9f223d9dc26136301cd3a3534057749ff
--- /dev/null
+++ b/refnet/sampling/sampler.py
@@ -0,0 +1,192 @@
+import dataclasses
+import torch
+import k_diffusion
+import inspect
+
+from types import SimpleNamespace
+from refnet.util import default
+from .scheduler import schedulers, schedulers_map
+from .denoiser import CFGDenoiser
+
+# Global sampler defaults (WebUI-style option names), read via getattr in
+# KDiffusionSampler.initialize().
+defaults = SimpleNamespace(**{
+    "eta_ddim": 0.0,
+    "eta_ancestral": 1.0,
+    "ddim_discretize": "uniform",
+    "s_churn": 0.0,
+    "s_tmin": 0.0,
+    "s_noise": 1.0,
+    "k_sched_type": "Automatic",
+    "sigma_min": 0.0,
+    "sigma_max": 0.0,
+    "rho": 0.0,
+    "eta_noise_seed_delta": 0,
+    "always_discard_next_to_last_sigma": False,
+})
+
+@dataclasses.dataclass
+class Sampler:
+    """Static description of one k-diffusion sampler entry."""
+    label: str      # display name shown in the UI
+    funcname: str   # attribute name in k_diffusion.sampling (or a callable)
+    aliases: list   # alternative lookup names
+    options: dict   # per-sampler flags (scheduler, brownian_noise, solver_type, ...)
+
+
+# Registry of supported k-diffusion samplers and their per-sampler options.
+samplers_k_diffusion = [
+    Sampler('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {'scheduler': 'karras'}),
+    Sampler('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
+    Sampler('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde'], {'scheduler': 'exponential', "brownian_noise": True}),
+    Sampler('DPM++ 2M SDE Heun', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun'], {'scheduler': 'exponential', "brownian_noise": True, "solver_type": "heun"}),
+    Sampler('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
+    Sampler('DPM++ 3M SDE', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde'], {'scheduler': 'exponential', 'discard_next_to_last_sigma': True, "brownian_noise": True}),
+    Sampler('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}),
+    Sampler('Euler', 'sample_euler', ['k_euler'], {}),
+    Sampler('LMS', 'sample_lms', ['k_lms'], {}),
+    Sampler('Heun', 'sample_heun', ['k_heun'], {"second_order": True}),
+    Sampler('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "second_order": True}),
+    Sampler('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
+    Sampler('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
+    Sampler('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True})
+]
+
+# Optional keyword arguments each sampling function accepts; forwarded only
+# when present in the function's signature (see KDiffusionSampler.initialize).
+sampler_extra_params = {
+    'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
+    'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
+    'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
+    'sample_dpm_fast': ['s_noise'],
+    'sample_dpm_2_ancestral': ['s_noise'],
+    'sample_dpmpp_2s_ancestral': ['s_noise'],
+    'sample_dpmpp_sde': ['s_noise'],
+    'sample_dpmpp_2m_sde': ['s_noise'],
+    'sample_dpmpp_3m_sde': ['s_noise'],
+}
+
+def kdiffusion_sampler_list():
+ return [k.label for k in samplers_k_diffusion]
+
+
+# Label -> Sampler lookup; name -> sigma-schedule function lookup.
+# NOTE(review): k_diffusion_scheduler appears unused in this module —
+# presumably kept for external consumers; verify before removing.
+k_diffusion_samplers_map = {x.label: x for x in samplers_k_diffusion}
+k_diffusion_scheduler = {x.name: x.function for x in schedulers}
+
+def exists(v):
+ return v is not None
+
+
+class KDiffusionSampler:
+    """
+    Thin driver around a k-diffusion sampling function: wraps the model in a
+    CFGDenoiser, builds the sigma schedule from the selected Scheduler entry,
+    and forwards per-sampler extra options when the function accepts them.
+    """
+    def __init__(self, sampler, scheduler, sd, device):
+        # Resolve the Sampler entry by its display label.
+        self.config = k_diffusion_samplers_map[sampler]
+        funcname = self.config.funcname
+
+        self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, funcname)
+        self.scheduler_name = scheduler
+        self.sd = CFGDenoiser(sd, device)
+        self.model_wrap = self.sd.model_wrap
+        self.device = device
+
+        # Stochastic sampler knobs (Karras et al. churn parameters).
+        self.s_min_uncond = None
+        self.s_churn = 0.0
+        self.s_tmin = 0.0
+        self.s_tmax = float('inf')
+        self.s_noise = 1.0
+
+        self.eta_option_field = 'eta_ancestral'
+        self.eta_infotext_field = 'Eta'
+        self.eta_default = 1.0
+        self.eta = None
+
+        self.extra_params = []
+
+        # Clamp the wrapped model's sigma range if the model specifies one.
+        # model_wrap.sigmas is ascending: index 0 is the min, -1 the max.
+        if exists(sd.sigma_max) and exists(sd.sigma_min):
+            self.model_wrap.sigmas[-1] = sd.sigma_max
+            self.model_wrap.sigmas[0] = sd.sigma_min
+
+    def initialize(self):
+        """Collect the kwargs this sampling function accepts, from `defaults`."""
+        self.eta = getattr(defaults, self.eta_option_field, 0.0)
+
+        extra_params_kwargs = {}
+        for param_name in self.extra_params:
+            if param_name in inspect.signature(self.func).parameters:
+                extra_params_kwargs[param_name] = getattr(self, param_name)
+
+        if 'eta' in inspect.signature(self.func).parameters:
+            extra_params_kwargs['eta'] = self.eta
+
+        if len(self.extra_params) > 0:
+            s_churn = getattr(defaults, 's_churn', self.s_churn)
+            s_tmax = getattr(defaults, 's_tmax', self.s_tmax) or self.s_tmax  # 0 = inf
+            s_tmin = getattr(defaults, 's_tmin', self.s_tmin)
+            s_noise = getattr(defaults, 's_noise', self.s_noise)
+
+            # Only override values the function accepts and that differ.
+            if 's_churn' in extra_params_kwargs and s_churn != self.s_churn:
+                extra_params_kwargs['s_churn'] = s_churn
+                self.s_churn = s_churn
+            if 's_tmin' in extra_params_kwargs and s_tmin != self.s_tmin:
+                extra_params_kwargs['s_tmin'] = s_tmin
+                self.s_tmin = s_tmin
+            if 's_tmax' in extra_params_kwargs and s_tmax != self.s_tmax:
+                extra_params_kwargs['s_tmax'] = s_tmax
+                self.s_tmax = s_tmax
+            if 's_noise' in extra_params_kwargs and s_noise != self.s_noise:
+                extra_params_kwargs['s_noise'] = s_noise
+                self.s_noise = s_noise
+
+        return extra_params_kwargs
+
+    def create_noise_sampler(self, x, sigmas, seed):
+        """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
+        from k_diffusion.sampling import BrownianTreeNoiseSampler
+        sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
+        return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed)
+
+    def get_sigmas(self, steps, sigmas_min=None, sigmas_max=None):
+        """Build the sigma schedule for `steps`, honoring per-sampler options."""
+        discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
+
+        # One extra step so the discarded sigma doesn't cost a real step.
+        steps += 1 if discard_next_to_last_sigma else 0
+
+        if self.scheduler_name == 'Automatic':
+            self.scheduler_name = self.config.options.get('scheduler', None)
+
+        scheduler = schedulers_map.get(self.scheduler_name)
+        sigma_min = default(sigmas_min, self.model_wrap.sigma_min)
+        sigma_max = default(sigmas_max, self.model_wrap.sigma_max)
+
+        if scheduler is None or scheduler.function is None:
+            sigmas = self.model_wrap.get_sigmas(steps)
+        else:
+            sigmas_kwargs = {'sigma_min': sigma_min, 'sigma_max': sigma_max}
+
+            if scheduler.need_inner_model:
+                sigmas_kwargs['inner_model'] = self.model_wrap
+
+            sigmas = scheduler.function(n=steps, **sigmas_kwargs, device=self.device)
+
+        if discard_next_to_last_sigma:
+            sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
+
+        return sigmas
+
+
+    def __call__(self, x, sigmas, sampler_extra_args, seed, deterministic, steps=None):
+        """Scale the initial noise by sigmas[0] and run the sampling function."""
+        x = x * sigmas[0]
+
+        extra_params_kwargs = self.initialize()
+        parameters = inspect.signature(self.func).parameters
+
+        if 'n' in parameters:
+            extra_params_kwargs['n'] = steps
+
+        if 'sigma_min' in parameters:
+            extra_params_kwargs['sigma_min'] = sigmas[sigmas > 0].min()
+            extra_params_kwargs['sigma_max'] = sigmas.max()
+
+        if 'sigmas' in parameters:
+            extra_params_kwargs['sigmas'] = sigmas
+
+        if self.config.options.get('brownian_noise', False):
+            noise_sampler = self.create_noise_sampler(x, sigmas, seed) if deterministic else None
+            extra_params_kwargs['noise_sampler'] = noise_sampler
+
+        if self.config.options.get('solver_type', None) == 'heun':
+            extra_params_kwargs['solver_type'] = 'heun'
+
+        return self.func(self.sd, x, extra_args=sampler_extra_args, disable=False, **extra_params_kwargs)
diff --git a/refnet/sampling/scheduler.py b/refnet/sampling/scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2aadb9806ea66dcafa06a0d4ca8d163227a48ca
--- /dev/null
+++ b/refnet/sampling/scheduler.py
@@ -0,0 +1,42 @@
+import torch
+import k_diffusion
+import dataclasses
+
@dataclasses.dataclass
class Scheduler:
    """Descriptor for one noise schedule available to the sampler UI.

    `function` produces the sigma sequence; `None` marks the 'Automatic'
    entry, which defers to the model's configured default schedule.
    String annotations are used to avoid importing `typing`/`Callable`
    at runtime (the previous `function: any` annotated with the builtin
    `any` function, which was misleading).
    """
    name: str                               # internal identifier
    label: str                              # human-readable UI label
    function: "Callable | None"             # sigma generator, or None for automatic

    default_rho: float = -1.0               # rho hyper-parameter; -1 means "not applicable"
    need_inner_model: bool = False          # True when the generator queries the wrapped model
    aliases: "list[str] | None" = None      # alternative labels accepted in lookups
+
+
def uniform(n, sigma_min, sigma_max, inner_model, device):
    """Uniform schedule: delegate entirely to the wrapped model's own sigmas.

    `sigma_min`, `sigma_max` and `device` are accepted only for signature
    parity with the other scheduler functions; they are unused here.
    """
    return inner_model.get_sigmas(n)
+
+
def sgm_uniform(n, sigma_min, sigma_max, inner_model, device):
    """SGM-style schedule: n sigmas uniformly spaced in t-space plus a final 0.

    Timesteps are sampled on a uniform grid between t(sigma_max) and
    t(sigma_min); the last grid point is dropped and a terminal 0.0 sigma
    is appended, yielding n + 1 values.
    """
    t_max = inner_model.sigma_to_t(torch.tensor(sigma_max))
    t_min = inner_model.sigma_to_t(torch.tensor(sigma_min))
    timesteps = torch.linspace(t_max, t_min, n + 1)[:-1]
    sigs = [inner_model.t_to_sigma(t) for t in timesteps]
    sigs.append(0.0)
    return torch.FloatTensor(sigs).to(device)
+
# Registry of the available noise schedulers.
# - 'Automatic' (function=None) defers to the model's configured default schedule.
# - 'uniform' / 'sgm_uniform' query the wrapped model (need_inner_model=True).
# - Karras / exponential / polyexponential come from k-diffusion's sampling module.
schedulers = [
    Scheduler('automatic', 'Automatic', None),
    Scheduler('uniform', 'Uniform', uniform, need_inner_model=True),
    Scheduler('karras', 'Karras', k_diffusion.sampling.get_sigmas_karras, default_rho=7.0),
    Scheduler('exponential', 'Exponential', k_diffusion.sampling.get_sigmas_exponential),
    Scheduler('polyexponential', 'Polyexponential', k_diffusion.sampling.get_sigmas_polyexponential, default_rho=1.0),
    Scheduler('sgm_uniform', 'SGM Uniform', sgm_uniform, need_inner_model=True, aliases=["SGMUniform"]),
]
+
def get_noise_schedulers():
    """Return the human-readable labels of every registered scheduler."""
    return [entry.label for entry in schedulers]
+
# Lookup table addressable by both internal name and UI label.
schedulers_map = {**{x.name: x for x in schedulers}, **{x.label: x for x in schedulers}}
\ No newline at end of file
diff --git a/refnet/sampling/tps_transformation.py b/refnet/sampling/tps_transformation.py
new file mode 100644
index 0000000000000000000000000000000000000000..e43a88d6af4a1340ec608f23e0b78349b3a05cca
--- /dev/null
+++ b/refnet/sampling/tps_transformation.py
@@ -0,0 +1,203 @@
+'''
+Calculate warped image using control point manipulation on a thin plate (TPS)
+Based on Herve Lombaert's 2006 web article
+"Manual Registration with Thin Plates"
+(https://profs.etsmtl.ca/hlombaert/thinplates/)
+
+Implementation by Yucheol Jung
+'''
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import PIL.Image as Image
+import torchvision.transforms as tf
+
+
def tps_warp(images, num_points=10, perturbation_strength=10, random=True, pts_before=None, pts_after=None):
    """Apply a thin-plate-spline warp to a batch of images.

    Args:
        images: tensor of shape (N, C, H, W).
        num_points: number of random control points per image (random mode).
        perturbation_strength: std-dev (in pixels) of the random point shifts.
        random: when True, sample control points randomly and ignore the
            provided `pts_before` / `pts_after`; when False the caller must
            supply both point sets of shape (N, T, 2).

    Returns:
        The warped image batch (same shape as `images`).
    """
    if random:
        b, c, h, w = images.shape
        device, dtype = images.device, images.dtype
        # Control points are in pixel space with (x, y) coordinate order:
        # _tps_warp normalizes index 0 by the width W and index 1 by the
        # height H, so the scale vector must be [w, h]. (The previous
        # [[h, w]] scale transposed the axes for non-square images; it also
        # used torch.Tensor, which silently dropped the input dtype.)
        scale = torch.tensor([[[w, h]]], dtype=dtype, device=device)
        pts_before = torch.rand([b, num_points, 2], dtype=dtype, device=device) * scale
        pts_after = pts_before + torch.randn([b, num_points, 2], dtype=dtype, device=device) * perturbation_strength
    return _tps_warp(images, pts_before, pts_after)
+
def _tps_warp(im, pts_before, pts_after, normalize=True):
    '''
    Deforms image according to movement of pts_before and pts_after.

    Solves the classic thin-plate-spline system (Lombaert, "Manual
    Registration with Thin Plates") once per batch element, then resamples
    the image with grid_sample along the resulting flow field.

    Args)
        im          torch.Tensor object of size NxCxHxW
        pts_before  torch.Tensor object of size NxTx2 (T is # control pts)
        pts_after   torch.Tensor object of size NxTx2 (T is # control pts)
        normalize   when True, points are given in pixel coordinates and are
                    rescaled here to grid_sample's [-1, 1] range
                    (index 0 = x / width, index 1 = y / height)
    '''
    # check input requirements
    assert (4 == im.dim())
    assert (3 == pts_after.dim())
    assert (3 == pts_before.dim())
    N = im.size()[0]
    assert (N == pts_after.size()[0] and N == pts_before.size()[0])
    assert (2 == pts_after.size()[2] and 2 == pts_before.size()[2])
    T = pts_after.size()[1]
    assert (T == pts_before.size()[1])
    H = im.size()[2]
    W = im.size()[3]

    if normalize:
        # Map pixel coordinates to [-1, 1]; clone first so the caller's
        # tensors are not mutated in place.
        pts_after = pts_after.clone()
        pts_after[:, :, 0] /= 0.5 * W
        pts_after[:, :, 1] /= 0.5 * H
        pts_after -= 1
        pts_before = pts_before.clone()
        pts_before[:, :, 0] /= 0.5 * W
        pts_before[:, :, 1] /= 0.5 * H
        pts_before -= 1

    def construct_P():
        '''
        Consturcts matrix P of size NxTx3 where
            P[n,i,0] := 1
            P[n,i,1:] := pts_after[n]
        '''
        # Create matrix P with same configuration as 'pts_after'
        P = pts_after.new_zeros((N, T, 3))
        P[:, :, 0] = 1
        P[:, :, 1:] = pts_after

        return P

    def calc_U(pt1, pt2):
        '''
        Calculate distance U between pt1 and pt2

            U(r) := r**2 * log(r)
        where
            r := |pt1 - pt2|_2

        Args)
            pt1 torch.Tensor object, last dim is always 2
            pt2 torch.Tensor object, last dim is always 2

        Shapes broadcast against each other; the returned tensor has the
        broadcast shape with the trailing coordinate dim reduced away.
        '''
        assert (2 == pt1.size()[-1])
        assert (2 == pt2.size()[-1])

        diff = pt1 - pt2
        sq_diff = diff ** 2
        sq_diff_sum = sq_diff.sum(-1)
        r = sq_diff_sum.sqrt()

        # Adds 1e-6 for numerical stability
        # (log(0) at coincident points would otherwise yield -inf).
        return (r ** 2) * torch.log(r + 1e-6)

    def construct_K():
        '''
        Consturcts matrix K of size NxTxT where
            K[n,i,j] := U(|pts_after[n,i] - pts_after[n,j]|_2)
        '''

        # Assuming the number of control points are small enough,
        # We just use for-loop for easy-to-read code

        # Create matrix K with same configuration as 'pts_after'
        K = pts_after.new_zeros((N, T, T))
        for i in range(T):
            for j in range(T):
                K[:, i, j] = calc_U(pts_after[:, i, :], pts_after[:, j, :])

        return K

    def construct_L():
        '''
        Consturcts matrix L of size Nx(T+3)x(T+3) where
            L[n] = [[ K[n]     P[n] ]]
                   [[ P[n]^T   0    ]]
        '''
        P = construct_P()
        K = construct_K()

        # Create matrix L with same configuration as 'K'
        L = K.new_zeros((N, T + 3, T + 3))

        # Fill L matrix (bottom-right 3x3 block stays zero by construction)
        L[:, :T, :T] = K
        L[:, :T, T:(T + 3)] = P
        L[:, T:(T + 3), :T] = P.transpose(1, 2)

        return L

    def construct_uv_grid():
        '''
        Returns H x W x 2 tensor uv with UV coordinate as its elements
            uv[:,:,0] is H x W grid of x values
            uv[:,:,1] is H x W grid of y values
        '''
        u_range = torch.arange(
            start=-1.0, end=1.0, step=2.0 / W, device=im.device)
        assert (W == u_range.size()[0])
        u = u_range.new_zeros((H, W))
        u[:] = u_range

        v_range = torch.arange(
            start=-1.0, end=1.0, step=2.0 / H, device=im.device)
        assert (H == v_range.size()[0])
        vt = v_range.new_zeros((W, H))
        vt[:] = v_range
        v = vt.transpose(0, 1)

        return torch.stack([u, v], dim=2)

    L = construct_L()
    VT = pts_before.new_zeros((N, T + 3, 2))
    # Use delta x and delta y as known heights of the surface
    VT[:, :T, :] = pts_before - pts_after

    # Solve Lx = VT
    # x is of shape (N, T+3, 2)
    # x[:,:,0] represents surface parameters for dx surface
    #   (dx values as surface height (z))
    # x[:,:,1] represents surface parameters for dy surface
    #   (dy values as surface height (z))
    x = torch.linalg.solve(L, VT)

    uv = construct_uv_grid()
    uv_batch = uv.repeat((N, 1, 1, 1))

    def calc_dxdy():
        '''
        Calculate surface height for each uv coordinate

        Returns NxHxWx2 tensor
        '''

        # control points of size NxTxHxWx2
        cp = uv.new_zeros((H, W, N, T, 2))
        cp[:, :, :] = pts_after
        cp = cp.permute([2, 3, 0, 1, 4])

        U = calc_U(uv, cp)  # U value matrix of size NxTxHxW
        w, a = x[:, :T, :], x[:, T:, :]  # w is of size NxTx2, a is of size Nx3x2
        w_x, w_y = w[:, :, 0], w[:, :, 1]  # NxT each
        a_x, a_y = a[:, :, 0], a[:, :, 1]  # Nx3 each
        # TPS interpolant: affine part (a0 + a1*u + a2*v) plus the weighted
        # radial-basis sum over control points.
        dx = (
            a_x[:, 0].repeat((H, W, 1)).permute(2, 0, 1) +
            torch.einsum('nhwd,nd->nhw', uv_batch, a_x[:, 1:]) +
            torch.einsum('nthw,nt->nhw', U, w_x))  # dx values of NxHxW
        dy = (
            a_y[:, 0].repeat((H, W, 1)).permute(2, 0, 1) +
            torch.einsum('nhwd,nd->nhw', uv_batch, a_y[:, 1:]) +
            torch.einsum('nthw,nt->nhw', U, w_y))  # dy values of NxHxW

        return torch.stack([dx, dy], dim=3)

    dxdy = calc_dxdy()
    flow_field = uv + dxdy

    # NOTE(review): grid_sample is called without align_corners, so the
    # default (False) applies — confirm this matches the intended sampling
    # convention for the [-1, 1] grid built above.
    return F.grid_sample(im, flow_field.to(im.dtype))
+
if __name__ == '__main__':
    # Quick visual smoke test: warp a sample image and display the result.
    num_points = 10
    perturbation_strength = 10
    source = Image.open("../../miniset/origin/109281263.jpg").convert("RGB")
    batch = tf.ToTensor()(source).unsqueeze(0)
    warped = tps_warp(batch, num_points=num_points, perturbation_strength=perturbation_strength).squeeze(0)
    tf.ToPILImage()(warped).show()
\ No newline at end of file
diff --git a/refnet/util.py b/refnet/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..96c430761864053f13e1b494be454d34cd4d6650
--- /dev/null
+++ b/refnet/util.py
@@ -0,0 +1,200 @@
+import re
+import os.path as osp
+
+import torch
+import torch.nn.functional as F
+import torchvision.transforms.functional as tf
+from torch.utils.checkpoint import checkpoint
+
+import numpy as np
+import itertools
+import importlib
+
+from tqdm import tqdm
+from inspect import isfunction
+from functools import wraps
+from safetensors import safe_open
+
+
+
def exists(x):
    """Return True when *x* holds a value (i.e. is not None)."""
    return x is not None
+
def append_dims(x, target_dims) -> torch.Tensor:
    """Append trailing singleton dimensions to *x* until it has *target_dims* dims.

    Raises:
        ValueError: if *x* already has more than *target_dims* dimensions.
    """
    missing = target_dims - x.ndim
    if missing < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
        )
    # Indexing with Ellipsis plus `missing` None entries adds the new axes as a view.
    return x[(Ellipsis,) + (None,) * missing]
+
+
def default(val, d):
    """Return *val* unless it is None, in which case fall back to *d*.

    A callable fallback (plain function or lambda) is invoked lazily so
    expensive defaults are only computed when actually needed.
    """
    if val is None:
        return d() if isfunction(d) else d
    return val
+
def expand_to_batch_size(x, bs):
    """Tile a tensor (or each tensor in a list) `bs` times along the batch dim.

    All non-batch dimensions are left unchanged.
    """
    def _tile(t):
        return t.repeat(bs, *([1] * (t.dim() - 1)))

    if isinstance(x, list):
        return [_tile(t) for t in x]
    return _tile(x)
+
+
def get_obj_from_str(string, reload=False):
    """Resolve a dotted path like "pkg.mod.Attr" to the object it names.

    When *reload* is True, the containing module is re-imported first so
    code changes are picked up.
    """
    module_name, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_name))
    return getattr(importlib.import_module(module_name, package=None), attr_name)
+
def instantiate_from_config(config):
    """Build the object described by a config dict.

    Expects keys "target" (dotted class path) and optionally "params"
    (constructor kwargs). The legacy sentinel strings mean "no module here"
    and yield None.
    """
    if "target" not in config:
        if config in ("__is_first_stage__", "__is_unconditional__"):
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+
def scaled_resize(x: torch.Tensor, scale_factor, interpolation_mode="bicubic"):
    """Resize a batched image tensor by a multiplicative scale factor."""
    resized = F.interpolate(x, scale_factor=scale_factor, mode=interpolation_mode)
    return resized
+
def get_crop_scale(h, w, bgh, bgw):
    """Compute fractional crop sizes (ch, cw) of a (bgh, bgw) background so
    its aspect ratio matches a generated image of size (h, w).

    Exactly one of ch/cw is 1.0; the other shrinks the longer background axis.
    """
    if (w / h) > (bgw / bgh):
        # Generated image is relatively wider: keep full width, crop height.
        return (h / w) * (bgw / bgh), 1.0
    # Otherwise keep full height and crop the width.
    return 1.0, (w / h) * (bgh / bgw)
+
def warp_resize(x: torch.Tensor, target_size, interpolation_mode="bicubic"):
    """Resize a 4-D batch tensor to an explicit (H, W), ignoring aspect ratio."""
    assert x.dim() == 4
    return F.interpolate(x, size=target_size, mode=interpolation_mode)
+
def resize_and_crop(x: torch.Tensor, ch, cw, th, tw):
    """Crop the top-left (ch*H, cw*W) region of *x* and resize it to (th, tw).

    Args:
        x: batched image tensor of shape (N, C, H, W).
        ch, cw: fractional crop height/width in (0, 1].
        th, tw: target output height/width in pixels.
    """
    # Only the spatial dims are needed; batch/channel sizes were unused locals.
    _, _, h, w = x.shape
    return tf.resized_crop(x, 0, 0, int(ch * h), int(cw * w), size=[th, tw])
+
+
def fitting_weights(model, sd):
    """Adapt checkpoint tensors in *sd* to the (possibly larger) shapes of
    *model*'s parameters/buffers by tiling the old values modulo the old size.

    Only the first two axes may differ; any trailing axes must match exactly.
    Entries missing from *sd* are skipped. The dict *sd* is modified in place
    and also returned.
    """
    n_params = len([name for name, _ in
                    itertools.chain(model.named_parameters(),
                                    model.named_buffers())])
    for name, param in tqdm(
            itertools.chain(model.named_parameters(),
                            model.named_buffers()),
            desc="Fitting old weights to new weights",
            total=n_params
    ):
        if not name in sd:
            continue
        old_shape = sd[name].shape
        new_shape = param.shape
        assert len(old_shape) == len(new_shape)
        if len(new_shape) > 2:
            # we only modify first two axes
            assert new_shape[2:] == old_shape[2:]
        # assumes first axis corresponds to output dim
        if not new_shape == old_shape:
            # NOTE(review): this clone is overwritten in both branches below;
            # it only serves as a fallback value.
            new_param = param.clone()
            old_param = sd[name]
            device = old_param.device
            if len(new_shape) == 1:
                # Vectorized 1D case: repeat old entries cyclically.
                # (No usage-count normalization is applied on this path.)
                new_param = old_param[torch.arange(new_shape[0], device=device) % old_shape[0]]
            elif len(new_shape) >= 2:
                # Vectorized 2D case: tile rows and columns cyclically.
                i_indices = torch.arange(new_shape[0], device=device)[:, None] % old_shape[0]
                j_indices = torch.arange(new_shape[1], device=device)[None, :] % old_shape[1]

                # Use advanced indexing to extract all values at once
                new_param = old_param[i_indices, j_indices]

                # Count how many times each old column is used
                n_used_old = torch.bincount(
                    torch.arange(new_shape[1], device=device) % old_shape[1],
                    minlength=old_shape[1]
                )

                # Map to new shape
                n_used_new = n_used_old[torch.arange(new_shape[1], device=device) % old_shape[1]]

                # Reshape for broadcasting against the (possibly >2-D) tensor
                n_used_new = n_used_new.reshape(1, new_shape[1])
                while len(n_used_new.shape) < len(new_shape):
                    n_used_new = n_used_new.unsqueeze(-1)

                # Normalize so duplicated input columns keep the same net
                # contribution as in the original layer.
                new_param = new_param / n_used_new

            sd[name] = new_param
    return sd
+
+
def count_params(model, verbose=False):
    """Return the total parameter count of *model*.

    When *verbose*, also print the count in millions.
    """
    n = 0
    for p in model.parameters():
        n += p.numel()
    if verbose:
        print(f"{model.__class__.__name__} has {n * 1.e-6:.2f} M params.")
    return n
+
+
# Checkpoint file extensions accepted by load_weights().
VALID_FORMATS = [".pt", ".pth", ".ckpt", ".safetensors", ".bin"]
+
def load_weights(path, weights_only=True):
    """Load a checkpoint state dict from *path*.

    Supports the extensions in VALID_FORMATS. Safetensors files are read
    tensor-by-tensor onto CPU; other formats go through torch.load. A
    top-level "state_dict" wrapper (Lightning-style checkpoints) is unwrapped.

    Args:
        path: checkpoint file path.
        weights_only: forwarded to torch.load for safe deserialization.
    """
    ext = osp.splitext(path)[-1]
    assert ext in VALID_FORMATS, f"Invalid checkpoint format {ext}"
    if ext == ".safetensors":
        sd = {}
        # Context manager ensures the file handle is released (the previous
        # code never closed the safe_open handle).
        with safe_open(path, framework="pt", device="cpu") as safe_sd:
            for key in safe_sd.keys():
                sd[key] = safe_sd.get_tensor(key)
    else:
        sd = torch.load(path, map_location="cpu", weights_only=weights_only)
    if "state_dict" in sd:
        sd = sd["state_dict"]
    return sd
+
+
def delete_states(sd, delete_keys: list[str] = (), skip_keys: list[str] = ()):
    """Delete entries of state dict *sd* whose keys match any regex in
    *delete_keys*, unless the key also matches a regex in *skip_keys*.

    Patterns are anchored at the start of the key (re.match). The dict is
    modified in place and returned.

    Fixes a crash in the previous implementation: with two or more skip
    patterns the nested loops could execute `del sd[k]` twice for the same
    key, raising KeyError.
    """
    for k in list(sd.keys()):
        should_delete = any(re.match(dk, k) is not None for dk in delete_keys)
        protected = any(re.match(sk, k) is not None for sk in skip_keys)
        if should_delete and not protected:
            del sd[k]
    return sd
+
+
def autocast(f, enabled=True):
    """Wrap *f* so it executes under CUDA automatic mixed precision.

    The wrapper inherits the current autocast GPU dtype and cache setting.
    Uses functools.wraps (consistent with checkpoint_wrapper below) so the
    wrapped function keeps its name/docstring for debugging.
    """
    @wraps(f)
    def do_autocast(*args, **kwargs):
        with torch.cuda.amp.autocast(
            enabled=enabled,
            dtype=torch.get_autocast_gpu_dtype(),
            cache_enabled=torch.is_autocast_cache_enabled(),
        ):
            return f(*args, **kwargs)

    return do_autocast
+
+
def checkpoint_wrapper(func):
    """Decorator for methods: route the call through torch gradient
    checkpointing when `self.checkpoint` is missing or truthy; otherwise
    call the method directly.
    """
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # Missing attribute defaults to "checkpoint enabled".
        use_checkpoint = getattr(self, 'checkpoint', True)
        if not use_checkpoint:
            return func(self, *args, **kwargs)

        def bound_func(*a, **kw):
            return func(self, *a, **kw)

        return checkpoint(bound_func, *args, use_reentrant=False, **kwargs)

    return wrapper