"""Gradio web UI for ColorizeDiffusion.

Builds the Blocks interface, wires widgets to the backend model/inference
helpers and launches the (blocking) server.

Run as a script: ``python app.py [-addr HOST] [-port PORT] [--share]``.
"""
import argparse
from functools import partial

import gradio as gr

from refnet.sampling import get_noise_schedulers, get_sampler_list
from backend import *

# Paper / weights / code links rendered as badges in the page header.
links = {
    "base": "https://arxiv.org/abs/2401.01456",
    "v1": "https://openaccess.thecvf.com/content/WACV2025/html/Yan_ColorizeDiffusion_Improving_Reference-Based_Sketch_Colorization_with_Latent_Diffusion_Model_WACV_2025_paper.html",
    "v1.5": "https://arxiv.org/abs/2502.19937v1",
    "v2": "https://arxiv.org/abs/2504.06895",
    "xl": "https://arxiv.org/abs/2601.04883",
    "weights": "https://huggingface.co/tellurion/colorizer/tree/main",
    "github": "https://github.com/tellurion-kanata/colorizeDiffusion",
}


def app_options():
    """Parse command-line options for the web app."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--server_name", '-addr', type=str, default="0.0.0.0")
    parser.add_argument("--server_port", '-port', type=int, default=7860)
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--enable_text_manipulation", '-manipulate', action="store_true")
    return parser.parse_args()


def init_interface(opt, *args, **kwargs) -> None:
    """Build the Gradio interface and launch it; blocks until the server exits.

    ``opt`` is the namespace returned by :func:`app_options`.  Text-manipulation
    widgets are created but hidden unless ``opt.enable_text_manipulation`` is set,
    so the ``inference`` argument list keeps a fixed shape either way.
    """
    sampler_list = get_sampler_list()
    scheduler_list = get_noise_schedulers()

    img_block = partial(gr.Image, type="pil", height=300, interactive=True, show_label=True, format="png")
    with gr.Blocks(
        title = "Colorize Diffusion",
        css_paths = "backend/style.css",
        theme = gr.themes.Ocean(),
        elem_id = "main-interface",
        analytics_enabled = False,
        fill_width = True
    ) as block:
        with gr.Row(elem_id="header-row", equal_height=True, variant="panel"):
            # NOTE(review): the header HTML was garbled in the source dump; it is
            # reconstructed here from the CSS class names (header-container /
            # app-header / title-text / paper-links-icons) -- confirm visually.
            gr.Markdown(f"""
            <div class="header-container">
                <div class="app-header">
                    <span class="emoji">🎨</span>
                    <span class="title-text">Colorize Diffusion</span>
                </div>
                <div class="paper-links-icons">
                    <a href="{links["base"]}" target="_blank">Base</a>
                    <a href="{links["v1"]}" target="_blank">V1 (WACV 2025)</a>
                    <a href="{links["v1.5"]}" target="_blank">V1.5</a>
                    <a href="{links["v2"]}" target="_blank">V2</a>
                    <a href="{links["xl"]}" target="_blank">XL</a>
                    <a href="{links["weights"]}" target="_blank">Weights</a>
                    <a href="{links["github"]}" target="_blank">GitHub</a>
                </div>
            </div>
            """)

        with gr.Row(elem_id="content-row", equal_height=False, variant="panel"):
            with gr.Column():
                # --- Text-manipulation controls (hidden unless enabled) ---
                with gr.Row(visible=opt.enable_text_manipulation):
                    target = gr.Textbox(label="Target prompt", value="", scale=2)
                    anchor = gr.Textbox(label="Anchor prompt", value="", scale=2)
                    control = gr.Textbox(label="Control prompt", value="", scale=2)
                with gr.Row(visible=opt.enable_text_manipulation):
                    target_scale = gr.Slider(label="Target scale", value=0.0, minimum=0, maximum=15.0, step=0.25, scale=2)
                    ts0 = gr.Slider(label="Threshold 0", value=0.5, minimum=0, maximum=1.0, step=0.01)
                    ts1 = gr.Slider(label="Threshold 1", value=0.55, minimum=0, maximum=1.0, step=0.01)
                    ts2 = gr.Slider(label="Threshold 2", value=0.65, minimum=0, maximum=1.0, step=0.01)
                    ts3 = gr.Slider(label="Threshold 3", value=0.95, minimum=0, maximum=1.0, step=0.01)
                with gr.Row(visible=opt.enable_text_manipulation):
                    enhance = gr.Checkbox(label="Enhance manipulation", value=False)
                    add_prompt = gr.Button(value="Add")
                    clear_prompt = gr.Button(value="Clear")
                    vis_button = gr.Button(value="Visualize")
                text_prompt = gr.Textbox(label="Final prompt", value="", lines=3, visible=opt.enable_text_manipulation)

                # --- Image inputs ---
                with gr.Row():
                    sketch_img = img_block(label="Sketch")
                    reference_img = img_block(label="Reference")
                    background_img = img_block(label="Background")

                style_enhance = gr.State(False)
                fg_enhance = gr.State(False)
                with gr.Row():
                    bg_enhance = gr.Checkbox(label="Low-level injection", value=False)
                    injection = gr.Checkbox(label="Attention injection", value=False)
                    autofit_size = gr.Checkbox(label="Autofit size", value=False)
                with gr.Row():
                    gs_r = gr.Slider(label="Reference guidance scale", minimum=1, maximum=15.0, value=4.0, step=0.5)
                    strength = gr.Slider(label="Reference strength", minimum=0, maximum=1, value=1, step=0.05)
                    fg_strength = gr.Slider(label="Foreground strength", minimum=0, maximum=1, value=1, step=0.05)
                    bg_strength = gr.Slider(label="Background strength", minimum=0, maximum=1, value=1, step=0.05)
                with gr.Row():
                    gs_s = gr.Slider(label="Sketch guidance scale", minimum=1, maximum=5.0, value=1.0, step=0.1)
                    ctl_scale = gr.Slider(label="Sketch strength", minimum=0, maximum=3, value=1, step=0.05)
                    mask_scale = gr.Slider(label="Background factor", minimum=0, maximum=2, value=1, step=0.05)
                    merge_scale = gr.Slider(label="Merging scale", minimum=0, maximum=1, value=0, step=0.05)
                with gr.Row():
                    bs = gr.Slider(label="Batch size", minimum=1, maximum=4, value=1, step=1, scale=1)
                    width = gr.Slider(label="Width", minimum=512, maximum=1536, value=1024, step=32, scale=2)
                with gr.Row():
                    step = gr.Slider(label="Step", minimum=1, maximum=100, value=20, step=1, scale=1)
                    height = gr.Slider(label="Height", minimum=512, maximum=1536, value=1024, step=32, scale=2)

                seed = gr.Slider(label="Seed", minimum=-1, maximum=MAXM_INT32, step=1, value=-1)
                with gr.Accordion("Advanced Settings", open=False):
                    with gr.Row():
                        crop = gr.Checkbox(label="Crop result", value=False, scale=1)
                        remove_fg = gr.Checkbox(label="Remove foreground in background input", value=False, scale=2)
                        rmbg = gr.Checkbox(label="Remove background in result", value=False, scale=2)
                        latent_inpaint = gr.Checkbox(label="Latent copy BG input", value=False, scale=2)
                    with gr.Row():
                        injection_control_scale = gr.Slider(label="Injection fidelity (sketch)", minimum=0.0,
                                                            maximum=2.0, value=0, step=0.05)
                        injection_fidelity = gr.Slider(label="Injection fidelity (reference)", minimum=0.0,
                                                       maximum=1.0, value=0.5, step=0.05)
                        injection_start_step = gr.Slider(label="Injection start step", minimum=0.0, maximum=1.0,
                                                         value=0, step=0.05)

                with gr.Row():
                    reuse_seed = gr.Button(value="Reuse Seed")
                    random_seed = gr.Button(value="Random Seed")

            with gr.Column():
                result_gallery = gr.Gallery(
                    label='Output', show_label=False, elem_id="gallery", preview=True, type="pil", format="png"
                )
                run_button = gr.Button("Generate", variant="primary", size="lg")
                with gr.Row():
                    mask_ts = gr.Slider(label="Reference mask threshold", minimum=0., maximum=1., value=0.5, step=0.01)
                    mask_ss = gr.Slider(label="Sketch mask threshold", minimum=0., maximum=1., value=0.05, step=0.01)
                    pad_scale = gr.Slider(label="Reference padding scale", minimum=1, maximum=2, value=1, step=0.05)

                with gr.Row():
                    sd_model = gr.Dropdown(choices=get_available_models(), label="Models",
                                           value=get_available_models()[0])
                    extractor_model = gr.Dropdown(choices=line_extractor_list,
                                                  label="Line extractor", value=default_line_extractor)
                    mask_model = gr.Dropdown(choices=mask_extractor_list, label="Reference mask extractor",
                                             value=default_mask_extractor)
                with gr.Row():
                    sampler = gr.Dropdown(choices=sampler_list, value="DPM++ 3M SDE", label="Sampler")
                    scheduler = gr.Dropdown(choices=scheduler_list, value=scheduler_list[0], label="Noise scheduler")
                    preprocessor = gr.Dropdown(choices=["none", "extract", "invert", "invert-webui"],
                                               label="Sketch preprocessor", value="invert")

                with gr.Row():
                    deterministic = gr.Checkbox(label="Deterministic batch seed", value=False)
                    save_memory = gr.Checkbox(label="Save memory", value=True)

        # Hidden states for unused advanced controls
        fg_disentangle_scale = gr.State(1.0)
        start_step = gr.State(0.0)
        end_step = gr.State(1.0)
        no_start_step = gr.State(-0.05)
        no_end_step = gr.State(-0.05)
        return_inter = gr.State(False)
        accurate = gr.State(False)
        enc_scale = gr.State(1.0)
        middle_scale = gr.State(1.0)
        low_scale = gr.State(1.0)
        ctl_scale_1 = gr.State(1.0)
        ctl_scale_2 = gr.State(1.0)
        ctl_scale_3 = gr.State(1.0)
        ctl_scale_4 = gr.State(1.0)

        add_prompt.click(fn=apppend_prompt,
                         inputs=[target, anchor, control, target_scale, enhance, ts0, ts1, ts2, ts3, text_prompt],
                         outputs=[target, anchor, control, target_scale, enhance, ts0, ts1, ts2, ts3, text_prompt])
        clear_prompt.click(fn=clear_prompts, outputs=[text_prompt])

        reuse_seed.click(fn=get_last_seed, outputs=[seed])
        random_seed.click(fn=reset_random_seed, outputs=[seed])

        extractor_model.input(fn=switch_extractor, inputs=[extractor_model])
        sd_model.input(fn=load_model, inputs=[sd_model])
        mask_model.input(fn=switch_mask_extractor, inputs=[mask_model])

        # Order must match backend.inference's positional signature exactly.
        ips = [style_enhance, bg_enhance, fg_enhance, fg_disentangle_scale,
               bs, sketch_img, reference_img, background_img, mask_ts, mask_ss, gs_r, gs_s, ctl_scale,
               ctl_scale_1, ctl_scale_2, ctl_scale_3, ctl_scale_4, fg_strength, bg_strength, merge_scale,
               mask_scale, height, width, seed, save_memory, step, injection, autofit_size,
               remove_fg, rmbg, latent_inpaint, injection_control_scale, injection_fidelity, injection_start_step,
               crop, pad_scale, start_step, end_step, no_start_step, no_end_step, return_inter, sampler, scheduler,
               preprocessor, deterministic, text_prompt, target, anchor, control, target_scale, ts0, ts1, ts2, ts3,
               enhance, accurate, enc_scale, middle_scale, low_scale, strength]

        run_button.click(
            fn = inference,
            inputs = ips,
            outputs = [result_gallery],
        )

        vis_button.click(
            fn = visualize,
            inputs = [reference_img, text_prompt, control, ts0, ts1, ts2, ts3],
            outputs = [result_gallery],
        )

    block.launch(
        server_name = opt.server_name,
        share = opt.share,
        server_port = opt.server_port,
    )


if __name__ == '__main__':
    opt = app_options()
    try:
        models = get_available_models()
        load_model(models[0])
        switch_extractor(default_line_extractor)
        switch_mask_extractor(default_mask_extractor)
        # FIX: init_interface() returns None and blocks until the server exits,
        # so the previous `interface = init_interface(opt)` assignment was dead.
        init_interface(opt)
    except Exception as e:
        print(f"Error initializing interface: {e}")
        raise
import os
import random
import traceback
import os.path as osp

import gradio as gr
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf

from refnet.util import instantiate_from_config
from preprocessor import create_model
from .functool import *

# Currently loaded diffusion model; None until load_model() first succeeds.
model = None
model_type = ""
current_checkpoint = ""
# Seed actually used by the most recent inference() call (None before first run).
global_seed = None

# Sketch-mask extractor is always instantiated; parked on CPU until needed.
smask_extractor = create_model("ISNet-sketch").cpu()

# Upper bound for user-supplied seeds.
# FIX: the original value 429496729 was a truncated 4294967295 (UINT32 max)
# and matched neither its own name nor any integer boundary; use the real
# INT32 max the name promises.
MAXM_INT32 = 2 ** 31 - 1

# HuggingFace model repository
HF_REPO_ID = "tellurion/colorizer"
MODEL_CACHE_DIR = "models"

# Model registry: checkpoint filename -> model_type (inference-config name)
MODEL_REGISTRY = {
    "sdxl.safetensors": "sdxl",
    "xlv2.safetensors": "xlv2",
}

model_types = ["sdxl", "xlv2"]

'''
    Gradio UI functions
'''


def get_available_models():
    """Return list of available model names from registry."""
    return list(MODEL_REGISTRY.keys())


def download_model(filename):
    """Download checkpoint ``filename`` from HuggingFace Hub unless cached.

    Returns the local filesystem path of the checkpoint.
    """
    os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
    local_path = osp.join(MODEL_CACHE_DIR, filename)
    if osp.exists(local_path):
        return local_path

    # FIX: these messages previously printed the literal "(unknown)"
    # instead of the checkpoint actually being fetched.
    print(f"Downloading {filename} from {HF_REPO_ID}...")
    gr.Info(f"Downloading {filename}...")
    path = hf_hub_download(
        repo_id=HF_REPO_ID,
        filename=filename,
        local_dir=MODEL_CACHE_DIR,
    )
    print(f"Downloaded to {path}")
    return path


def switch_extractor(type):
    """Load the line-art extractor named ``type`` into the module global."""
    global line_extractor
    try:
        line_extractor = create_model(type)
        gr.Info(f"Switched to {type} extractor")
    except Exception as e:
        print(f"Error info: {e}")
        # FIX: print(traceback.print_exc()) printed a spurious "None";
        # print_exc() already writes the traceback itself.
        traceback.print_exc()
        gr.Info(f"Failed in loading {type} extractor")


def switch_mask_extractor(type):
    """Load the reference-mask extractor named ``type`` into the module global."""
    global mask_extractor
    try:
        mask_extractor = create_model(type)
        gr.Info(f"Switched to {type} extractor")
    except Exception as e:
        print(f"Error info: {e}")
        traceback.print_exc()
        gr.Info(f"Failed in loading {type} extractor")


def apppend_prompt(target, anchor, control, scale, enhance, ts0, ts1, ts2, ts3, prompt):
    """Serialize one manipulation entry and append it to the prompt textbox.

    Returns reset widget values plus the updated prompt text (the output
    order mirrors the Gradio ``outputs`` wiring in app.py).

    NOTE(review): the triple-p name is kept deliberately -- it is exported
    by backend.__init__ and wired into the UI under this spelling.
    """
    target = target.strip()
    anchor = anchor.strip()
    control = control.strip()
    # Empty fields serialize as the sentinel "none".
    if target == "":
        target = "none"
    if anchor == "":
        anchor = "none"
    if control == "":
        control = "none"
    new_p = (f"\n[target] {target}; [anchor] {anchor}; [control] {control}; [scale] {str(scale)}; "
             f"[enhanced] {str(enhance)}; [ts0] {str(ts0)}; [ts1] {str(ts1)}; [ts2] {str(ts2)}; [ts3] {str(ts3)}")
    return "", "", "", 0.0, False, 0.5, 0.55, 0.65, 0.95, (prompt + new_p).strip()


def clear_prompts():
    """Reset the accumulated manipulation prompt."""
    return ""


def load_model(ckpt_name):
    """Instantiate (if needed) and load weights for checkpoint ``ckpt_name``.

    Re-instantiates the network only when the model type changes; otherwise
    just reloads weights into the existing instance.
    """
    global model, model_type, current_checkpoint
    config_root = "configs/inference"

    try:
        # Determine model type from registry or filename prefix
        new_model_type = MODEL_REGISTRY.get(ckpt_name, "")
        if not new_model_type:
            for key in model_types:
                if ckpt_name.startswith(key):
                    new_model_type = key
                    break

        # FIX: the original guard `not "model" in globals()` was always False
        # because `model = None` exists at module level; test the value.
        if model_type != new_model_type or model is None:
            if exists(model):
                del model
            config_path = osp.join(config_root, f"{new_model_type}.yaml")
            new_model = instantiate_from_config(OmegaConf.load(config_path).model).cpu().eval()
            print(f"Switched to {new_model_type} model, loading weights from [{ckpt_name}]...")
            model = new_model

        # Download model from HF Hub
        local_path = download_model(ckpt_name)

        # Checkpoints carrying "eps" in the filename use eps-parameterization.
        model.parameterization = "eps" if "eps" in ckpt_name else "v"
        model.init_from_ckpt(local_path, logging=True)
        model.switch_to_fp16()

        model_type = new_model_type
        current_checkpoint = ckpt_name
        print(f"Loaded model from [{ckpt_name}], model_type [{model_type}].")
        gr.Info("Loaded model successfully.")

    except Exception as e:
        print(f"Error type: {e}")
        traceback.print_exc()
        gr.Info("Failed in loading model.")


def get_last_seed():
    """Return the seed used by the last generation, or -1 before any run.

    FIX: ``global_seed or -1`` wrongly mapped a legitimate seed of 0 to -1.
    """
    return -1 if global_seed is None else global_seed


def reset_random_seed():
    """Return the sentinel -1, meaning "pick a fresh random seed"."""
    return -1


def visualize(reference, text, *args):
    """UI entry point: render manipulation heatmaps over ``reference``."""
    return visualize_heatmaps(model, reference, parse_prompts(text), *args)


def set_cas_scales(accurate, cas_args):
    """Build the strength/scale dict consumed by model.generate().

    ``cas_args`` is the tail of the inference() argument list:
    (enc_scale, middle_scale, low_scale, strength, *per-level scales).
    In the default (non-accurate) mode the three level scales are multiplied
    by the global reference strength.
    """
    enc_scale, middle_scale, low_scale, strength = cas_args[:4]
    if not accurate:
        scale_strength = {
            "level_control": True,
            "scales": {
                "encoder": enc_scale * strength,
                "middle": middle_scale * strength,
                "low": low_scale * strength,
            }
        }
    else:
        scale_strength = {
            "level_control": False,
            "scales": list(cas_args[4:])
        }
    return scale_strength


@torch.no_grad()
def inference(
        style_enhance, bg_enhance, fg_enhance, fg_disentangle_scale,
        bs, input_s, input_r, input_bg, mask_ts, mask_ss, gs_r, gs_s, ctl_scale,
        ctl_scale_1, ctl_scale_2, ctl_scale_3, ctl_scale_4,
        fg_strength, bg_strength, merge_scale, mask_scale, height, width, seed, low_vram, step,
        injection, autofit_size, remove_fg, rmbg, latent_inpaint, infid_x, infid_r, injstep, crop, pad_scale,
        start_step, end_step, no_start_step, no_end_step, return_inter, sampler, scheduler, preprocess,
        deterministic, text, target, anchor, control, target_scale, ts0, ts1, ts2, ts3, enhance, accurate,
        *args
):
    """Run one colorization pass and return a list of PIL/ndarray results.

    The positional order mirrors the ``ips`` list built in app.py; ``*args``
    carries the CAS scale tail consumed by :func:`set_cas_scales`.
    """
    global global_seed, line_extractor, mask_extractor
    global_seed = seed if seed > -1 else random.randint(0, MAXM_INT32)
    torch.manual_seed(global_seed)

    # Auto-fit size based on sketch dimensions: keep aspect ratio, target
    # ~1024^2 pixels, round to multiples of 32, clamp to [768, 1536].
    if autofit_size and exists(input_s):
        sketch_w, sketch_h = input_s.size
        aspect_ratio = sketch_w / sketch_h
        target_area = 1024 * 1024
        new_h = int((target_area / aspect_ratio) ** 0.5)
        new_w = int(new_h * aspect_ratio)
        height = ((new_h + 16) // 32) * 32
        width = ((new_w + 16) // 32) * 32
        height = max(768, min(1536, height))
        width = max(768, min(1536, width))
        gr.Info(f"Auto-fitted size: {width}x{height}")

    smask, rmask, bgmask = None, None, None
    manipulation_params = parse_prompts(text, target, anchor, control, target_scale, ts0, ts1, ts2, ts3, enhance)
    inputs = preprocessing_inputs(
        sketch = input_s,
        reference = input_r,
        background = input_bg,
        preprocess = preprocess,
        hook = injection,
        resolution = (height, width),
        extractor = line_extractor,
        pad_scale = pad_scale,
    )
    sketch, reference, background, original_shape, inject_xr, inject_xs, white_sketch = inputs

    cond = {"reference": reference, "sketch": sketch, "background": background}
    mask_guided = bg_enhance or fg_enhance

    if exists(white_sketch) and exists(reference) and mask_guided:
        mask_extractor.cuda()
        smask_extractor.cuda()
        smask = smask_extractor.proceed(
            x=white_sketch, pil_x=input_s, th=height, tw=width, threshold=mask_ss, crop=False
        )

        if exists(background) and remove_fg:
            # Blank out foreground pixels of the background input.
            bgmask = mask_extractor.proceed(x=background, pil_x=input_bg, threshold=mask_ts, dilate=True)
            filtered_background = torch.where(bgmask < mask_ts, background, torch.ones_like(background))
            cond.update({"background": filtered_background, "rmask": bgmask})
        else:
            rmask = mask_extractor.proceed(x=reference, pil_x=input_r, threshold=mask_ts, dilate=True)
            cond.update({"rmask": rmask})
            rmask = torch.where(rmask > 0.5, torch.ones_like(rmask), torch.zeros_like(rmask))
        cond.update({"smask": smask})
        smask_extractor.cpu()
        mask_extractor.cpu()

    scale_strength = set_cas_scales(accurate, args)
    ctl_scales = [ctl_scale_1, ctl_scale_2, ctl_scale_3, ctl_scale_4]
    ctl_scales = [t * ctl_scale for t in ctl_scales]

    results = model.generate(
        # Colorization mode
        style_enhance = style_enhance,
        bg_enhance = bg_enhance,
        fg_enhance = fg_enhance,
        fg_disentangle_scale = fg_disentangle_scale,
        latent_inpaint = latent_inpaint,

        # Conditional inputs
        cond = cond,
        ctl_scale = ctl_scales,
        merge_scale = merge_scale,
        mask_scale = mask_scale,
        mask_thresh = mask_ts,
        mask_thresh_sketch = mask_ss,

        # Sampling settings
        bs = bs,
        gs = [gs_r, gs_s],
        sampler = sampler,
        scheduler = scheduler,
        start_step = start_step,
        end_step = end_step,
        no_start_step = no_start_step,
        no_end_step = no_end_step,
        strength = scale_strength,
        fg_strength = fg_strength,
        bg_strength = bg_strength,
        seed = global_seed,
        deterministic = deterministic,
        height = height,
        width = width,
        step = step,

        # Injection settings
        injection = injection,
        injection_cfg = infid_r,
        injection_control = infid_x,
        injection_start_step = injstep,
        hook_xr = inject_xr,
        hook_xs = inject_xs,

        # Additional settings
        low_vram = low_vram,
        return_intermediate = return_inter,
        manipulation_params = manipulation_params,
    )

    if rmbg:
        # NOTE(review): this branch moves mask_extractor to CUDA but runs
        # smask_extractor (possibly still on CPU) -- looks inconsistent;
        # confirm which extractor is intended here.
        mask_extractor.cuda()
        mask = smask_extractor.proceed(x=-sketch, threshold=mask_ss).repeat(results.shape[0], 1, 1, 1)
        results = torch.where(mask >= mask_ss, results, torch.ones_like(results))
        mask_extractor.cpu()

    results = postprocess(results, sketch, reference, background, crop, original_shape,
                          mask_guided, smask, rmask, bgmask, mask_ts, mask_ss)
    torch.cuda.empty_cache()
    gr.Info("Generation completed.")
    return results
import cv2
import numpy as np
import PIL.Image as Image

import torch
import torch.nn as nn
import torchvision.transforms as transforms

from functools import partial

# Longest side accepted by the heatmap visualizer before downscaling.
# (sic: misspelled name kept -- it is a module attribute others may import.)
maxium_resolution = 4096
# 256 visual tokens arranged on a square grid -> 16 per side.
token_length = int(256 ** 0.5)

def exists(v):
    """True when ``v`` is not None."""
    return v is not None

# Bicubic, antialiased resize factory: resize((h, w)) -> transform.
resize = partial(transforms.Resize, interpolation=transforms.InterpolationMode.BICUBIC, antialias=True)

def resize_image(img, new_size, w, h):
    """Resize so the longer of (w, h) becomes ``new_size``, keeping aspect."""
    if w > h:
        img = resize((int(h / w * new_size), new_size))(img)
    else:
        img = resize((new_size, int(w / h * new_size)))(img)
    return img

def pad_image(image: torch.Tensor, h, w):
    """Center ``image`` (NCHW) on a (h, w) canvas filled with -1.

    Returns the padded tensor and the (left, top, width, height) placement
    needed later to crop the result back out.
    """
    b, c, height, width = image.shape
    square_image = -torch.ones([b, c, h, w], device=image.device)
    left = (w - width) // 2
    top = (h - height) // 2
    square_image[:, :, top:top+height, left:left+width] = image

    return square_image, (left, top, width, height)


def pad_image_with_margin(image: Image.Image, scale):
    """Widen a PIL image by ``scale`` with white margins, content centered."""
    w, h = image.size
    nw = int(w * scale)
    bg = Image.new('RGB', (nw, h), (255, 255, 255))
    bg.paste(image, ((nw-w)//2, 0))
    return bg


def crop_image_from_square(square_image, original_dim):
    """Undo :func:`pad_image` on a PIL image using its placement tuple."""
    left, top, width, height = original_dim
    return square_image.crop((left, top, left + width, top + height))


def to_tensor(x, inverse=False):
    """PIL/ndarray -> normalized [-1, 1] CUDA tensor (1, C, H, W)."""
    x = transforms.ToTensor()(x).unsqueeze(0)
    x = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(x).cuda()
    return x if not inverse else -x

def to_numpy(x, denormalize=True):
    """Tensor -> uint8 ndarray.

    denormalize=True:  [-1, 1] NCHW batch -> NHWC uint8 batch.
    denormalize=False: [0, 1] mask -> single-channel HW uint8 (first item).
    """
    if denormalize:
        return ((x.clamp(-1, 1) + 1.) * 127.5).permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
    else:
        return (x.clamp(0, 1) * 255)[0][0].cpu().numpy().astype(np.uint8)

def lineart_standard(x: Image.Image):
    """WebUI-style "lineart standard": difference-of-gaussian line extraction."""
    x = np.array(x).astype(np.float32)
    g = cv2.GaussianBlur(x, (0, 0), 6.0)
    intensity = np.min(g - x, axis=2).clip(0, 255)
    intensity /= max(16, np.median(intensity[intensity > 8]))
    intensity *= 127
    intensity = np.repeat(np.expand_dims(intensity, 2), 3, axis=2)
    result = to_tensor(intensity.clip(0, 255).astype(np.uint8))
    return result

def preprocess_sketch(sketch, resolution, preprocess="none", extractor=None, new=False):
    """Convert a PIL sketch to a model-ready tensor at ``resolution``.

    Returns (sketch tensor, placement tuple, white-on-black companion).
    """
    w, h = sketch.size
    th, tw = resolution
    r = min(th/h, tw/w)

    if preprocess == "none":
        sketch = to_tensor(sketch)
    elif preprocess == "invert":
        sketch = to_tensor(sketch, inverse=True)
    elif preprocess == "invert-webui":
        sketch = lineart_standard(sketch)
    else:
        # "extract": run the learned line extractor at a fixed 768x768.
        sketch = extractor.proceed(resize((768, 768))(sketch)).repeat(1, 3, 1, 1)

    sketch, original_shape = pad_image(resize((int(h*r), int(w*r)))(sketch), th, tw)
    if new:
        sketch = ((sketch + 1) / 2.).clamp(0, 1)
        white_sketch = 1 - sketch
    else:
        white_sketch = -sketch
    return sketch, original_shape, white_sketch


@torch.no_grad()
def preprocessing_inputs(
        sketch: Image.Image,
        reference: Image.Image,
        background: Image.Image,
        preprocess: str,
        hook: bool,
        resolution: tuple[int, int],
        extractor: nn.Module,
        pad_scale: float = 1.,
        new = False
):
    """Prepare all image inputs for generation.

    Returns (sketch, reference, background, original_shape, inject_xr,
    inject_xs, white_sketch); missing images yield None / blank tensors.
    """
    extractor = extractor.cuda()
    h, w = resolution
    if exists(sketch):
        sketch, original_shape, white_sketch = preprocess_sketch(sketch, resolution, preprocess, extractor, new)
    else:
        # No sketch: blank canvas (black in [-1,1] space, or zeros when new).
        sketch = torch.zeros([1, 3, h, w], device="cuda") if new else -torch.ones([1, 3, h, w], device="cuda")
        white_sketch = None
        original_shape = (0, 0, h, w)

    inject_xs = None
    if hook:
        assert exists(reference) and exists(extractor)
        maxm = max(h, w)
        # inject_xs = resize((h, w))(extractor.proceed(resize((maxm, maxm))(reference)).repeat(1, 3, 1, 1))
        inject_xr = to_tensor(resize((h, w))(reference))
    else:
        inject_xr = None
    extractor = extractor.cpu()

    if exists(reference):
        if pad_scale > 1.:
            reference = pad_image_with_margin(reference, pad_scale)
        reference = to_tensor(reference)

    if exists(background):
        if pad_scale > 1.:
            background = pad_image_with_margin(background, pad_scale)
        background = to_tensor(background)

    return sketch, reference, background, original_shape, inject_xr, inject_xs, white_sketch

def postprocess(results, sketch, reference, background, crop, original_shape,
                mask_guided, smask, rmask, bgmask, mask_ts, mask_ss, new=False):
    """Assemble the output gallery: results plus inputs and mask composites."""
    results = to_numpy(results)
    sketch = to_numpy(sketch, not new)[0]

    results_list = []
    for result in results:
        result = Image.fromarray(result)
        if crop:
            result = crop_image_from_square(result, original_shape)
        results_list.append(result)

    results_list.append(sketch)

    if exists(reference):
        reference = to_numpy(reference)[0]
        results_list.append(reference)
        # if vis_crossattn:
        #     results_list += visualize_attention_map(reference, results_list[0], vh, vw)

    if exists(background):
        background = to_numpy(background)[0]
        results_list.append(background)

        if exists(bgmask):
            background = Image.fromarray(background)
            results_list.append(Image.composite(
                background,
                Image.new("RGB", background.size, (255, 255, 255)),
                Image.fromarray(to_numpy(bgmask, denormalize=False), mode="L")
            ))
            results_list.append(Image.composite(
                Image.new("RGB", background.size, (255, 255, 255)),
                background,
                Image.fromarray(to_numpy(bgmask, denormalize=False), mode="L")
            ))

    # FIX: guard against smask being None (mask_guided can be requested
    # without a sketch/reference, in which case no sketch mask was built).
    if mask_guided and exists(smask):
        smask[smask < mask_ss] = 0
        results_list.append(Image.fromarray(to_numpy(smask, denormalize=False), mode="L"))

        if exists(rmask):
            reference = Image.fromarray(reference)
            rmask[rmask < mask_ts] = 0
            results_list.append(Image.fromarray(to_numpy(rmask, denormalize=False), mode="L"))
            results_list.append(Image.composite(
                reference,
                Image.new("RGB", reference.size, (255, 255, 255)),
                Image.fromarray(to_numpy(rmask, denormalize=False), mode="L")
            ))
            results_list.append(Image.composite(
                Image.new("RGB", reference.size, (255, 255, 255)),
                reference,
                Image.fromarray(to_numpy(rmask, denormalize=False), mode="L")
            ))

    return results_list


def parse_prompts(
        prompts: str,
        target: str = None,
        anchor: str = None,
        control: str = None,
        target_scale: float = None,
        ts0: float = None,
        ts1: float = None,
        ts2: float = None,
        ts3: float = None,
        enhance: bool = None
):
    """Parse the serialized manipulation prompt (see apppend_prompt) into lists.

    Optionally appends one extra in-flight entry from the live widgets.
    """
    targets = []
    anchors = []
    controls = []
    scales = []
    enhances = []
    thresholds_list = []

    # Field markers emitted by apppend_prompt; replaced by a split sentinel.
    # (renamed from `str`, which shadowed the builtin)
    markers = ["; [anchor] ", "; [control] ", "; [scale]", "; [enhanced]", "; [ts0]", "; [ts1]", "; [ts2]", "; [ts3]"]
    if prompts != "" and prompts is not None:
        ps_l = prompts.split('\n')
        for ps in ps_l:
            ps = ps.replace("[target] ", "")
            for marker in markers:
                ps = ps.replace(marker, "||||")

            p_l = ps.split("||||")
            targets.append(p_l[0])
            anchors.append(p_l[1])
            controls.append(p_l[2])
            scales.append(float(p_l[3]))
            # FIX: bool(p_l[4]) was True for ANY non-empty string, including
            # "False"; compare against the serialized text instead.
            enhances.append(p_l[4].strip() == "True")
            thresholds_list.append([float(p_l[5]), float(p_l[6]), float(p_l[7]), float(p_l[8])])

    if exists(target) and target != "":
        targets.append(target)
        anchors.append(anchor)
        controls.append(control)
        scales.append(target_scale)
        enhances.append(enhance)
        thresholds_list.append([ts0, ts1, ts2, ts3])

    return {
        "targets": targets,
        "anchors": anchors,
        "controls": controls,
        "target_scales": scales,
        "enhances": enhances,
        "thresholds_list": thresholds_list
    }


# Imported here (not at top) to mirror the original placement; moving it up
# may introduce an import cycle with refnet -- left as-is.
from refnet.sampling.manipulation import get_heatmaps
def visualize_heatmaps(model, reference, manipulation_params, control, ts0, ts1, ts2, ts3):
    """Overlay token-level manipulation heatmaps on ``reference``.

    Returns a one-element list holding the annotated RGB ndarray.
    """
    if reference is None:
        return []

    size = reference.size
    if size[0] > maxium_resolution or size[1] > maxium_resolution:
        # Downscale so the longer side equals maxium_resolution.
        if size[0] > size[1]:
            size = (maxium_resolution, int(float(maxium_resolution) / size[0] * size[1]))
        else:
            size = (int(float(maxium_resolution) / size[1] * size[0]), maxium_resolution)
        reference = reference.resize(size, Image.BICUBIC)

    reference = np.array(reference)
    scale_maps = get_heatmaps(model, to_tensor(reference), size[1], size[0],
                              control, ts0, ts1, ts2, ts3, **manipulation_params)

    scale_map = scale_maps[0] + scale_maps[1] + scale_maps[2] + scale_maps[3]
    heatmap = cv2.cvtColor(cv2.applyColorMap(scale_map, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
    result = cv2.addWeighted(reference, 0.3, heatmap, 0.7, 0)
    hu = size[1] // token_length
    wu = size[0] // token_length
    # Draw the token grid.  FIX: use token_length (=16) instead of a
    # hard-coded 16 so the grid tracks the token count.
    for i in range(token_length):
        result[i * hu, :] = (0, 0, 0)
    for i in range(token_length):
        result[:, i * wu] = (0, 0, 0)

    return [result]
+.gradio-container { + width: 100vw !important; + max-width: 100vw !important; + margin: 0 !important; + padding: 0 !important; + box-shadow: none !important; + border: none !important; + overflow-x: hidden !important; +} + +/* Header styling */ +#header-row { + background: white; + padding: 15px 20px; + margin-bottom: 20px; + box-shadow: var(--shadow-sm); + border-bottom: 1px solid rgba(0,0,0,0.05); +} + +.header-container { + width: 100%; + display: flex; + flex-direction: column; + align-items: center; + padding: 10px 0; +} + +.app-header { + display: flex; + align-items: center; + gap: 12px; + margin-bottom: 15px; +} + +.app-header .emoji { + font-size: 36px; +} + +/* Fix for Colorize Diffusion title visibility */ +.gradio-markdown h1, +.gradio-markdown h2, +#header-row h1, +#header-row h2, +.title-text, +.app-header .title-text { + display: inline-block !important; + visibility: visible !important; + opacity: 1 !important; + position: relative !important; + color: var(--primary-color) !important; + font-size: 32px !important; + font-weight: 800 !important; +} + +/* Badge links under the header */ +.paper-links-icons { + display: flex; + flex-wrap: wrap; + justify-content: center; + gap: 8px; + margin-top: 5px; +} + +.paper-links-icons a { + transition: transform 0.2s ease; + opacity: 0.9; +} + +.paper-links-icons a:hover { + transform: translateY(-3px); + opacity: 1; +} + +/* Content layout */ +#content-row { + padding: 0 20px 20px 20px; + max-width: 100%; + margin: 0 auto; +} + +/* Apply bold font to all text elements for better readability */ +p, span, label, button, input, textarea, select, .gradio-button, .gradio-checkbox, .gradio-dropdown, .gradio-textbox { + font-weight: var(--font-weight-normal); +} + +/* Make headings bolder */ +h1, h2, h3, h4, h5, h6 { + font-weight: var(--font-weight-bold); +} + +/* Improved font styling for Gradio UI elements */ +.gradio-container, +.gradio-container *, +.gradio-app, +.gradio-app * { + font-family: 'Roboto', 'Segoe 
UI', system-ui, -apple-system, sans-serif !important; + font-weight: 500 !important; +} + +/* Style for labels and slider labels */ +.gradio-container label, +.gradio-slider label, +.gradio-checkbox label, +.gradio-radio label, +.gradio-dropdown label, +.gradio-textbox label, +.gradio-number label, +.gradio-button, +.gradio-checkbox span, +.gradio-radio span { + font-family: 'Roboto', 'Segoe UI', system-ui, -apple-system, sans-serif !important; + font-weight: 600 !important; + letter-spacing: 0.01em; +} + +/* Style for buttons */ +button, +.gradio-button { + font-family: 'Roboto', 'Segoe UI', system-ui, -apple-system, sans-serif !important; + font-weight: 600 !important; +} + +/* Style for input values */ +input, +textarea, +select, +.gradio-textbox textarea, +.gradio-number input { + font-family: 'Roboto', 'Segoe UI', system-ui, -apple-system, sans-serif !important; + font-weight: 500 !important; +} + +/* Better styling for drop areas */ +.upload-box, +[data-testid="image"] { + font-family: 'Roboto', 'Segoe UI', system-ui, -apple-system, sans-serif !important; + font-weight: 500 !important; +} + +/* Additional styling for values in sliders and numbers */ +.wrap .wrap .wrap span { + font-family: 'Roboto', 'Segoe UI', system-ui, -apple-system, sans-serif !important; + font-weight: 600 !important; +} diff --git a/configs/inference/sdxl.yaml b/configs/inference/sdxl.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bc9d373f7428ab9c8567083f3c4ae215e70e5cc4 --- /dev/null +++ b/configs/inference/sdxl.yaml @@ -0,0 +1,88 @@ +model: + base_learning_rate: 1.0e-6 + target: refnet.models.colorizerXL.InferenceWrapper + params: + linear_start: 0.00085 + linear_end: 0.0120 + timesteps: 1000 + image_size: 128 + channels: 4 + scale_factor: 0.13025 + logits_embed: false + + unet_config: + target: refnet.modules.unet.DualCondUNetXL + params: + use_checkpoint: True + in_channels: 4 + out_channels: 4 + model_channels: 320 + adm_in_channels: 512 +# 
adm_in_channels: 2816 + num_classes: sequential + attention_resolutions: [4, 2] + num_res_blocks: 2 + channel_mult: [1, 2, 4] + num_head_channels: 64 + use_spatial_transformer: true + use_linear_in_transformer: true + transformer_depth: [1, 2, 10] + context_dim: 2048 + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + ddconfig: + double_z: true + z_channels: 4 + resolution: 512 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 4, 4] + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + + cond_stage_config: + target: refnet.modules.embedder.HFCLIPVisionModel + # target: refnet.modules.embedder.FrozenOpenCLIPImageEmbedder + params: + arch: ViT-bigG-14 + + control_encoder_config: +# target: refnet.modules.encoder.MultiEncoder + target: refnet.modules.encoder.MultiScaleAttentionEncoder + params: + in_ch: 3 + model_channels: 320 + ch_mults: [ 1, 2, 4 ] + + img_embedder_config: + target: refnet.modules.embedder.WDv14SwinTransformerV2 + + scalar_embedder_config: + target: refnet.modules.embedder.TimestepEmbedding + params: + embed_dim: 256 + + proj_config: + target: refnet.modules.proj.ClusterConcat +# target: refnet.modules.proj.RecoveryClusterConcat + params: + input_dim: 1280 + c_dim: 1024 + output_dim: 2048 + token_length: 196 + dim_head: 128 +# proj_config: +# target: refnet.modules.proj.LogitClusterConcat +# params: +# input_dim: 1280 +# c_dim: 1024 +# output_dim: 2048 +# token_length: 196 +# dim_head: 128 +# mlp_in_dim: 9083 +# mlp_ckpt_path: pretrained_models/proj.safetensors \ No newline at end of file diff --git a/configs/inference/xlv2.yaml b/configs/inference/xlv2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..035b09df80ad2082bef3a352a399ea98b706ab19 --- /dev/null +++ b/configs/inference/xlv2.yaml @@ -0,0 +1,108 @@ +model: + base_learning_rate: 1.0e-6 + target: refnet.models.v2-colorizerXL.InferenceWrapperXL + params: + linear_start: 0.00085 + linear_end: 0.0120 + timesteps: 
1000 + image_size: 128 + channels: 4 + scale_factor: 0.13025 + controller: true + + unet_config: + target: refnet.modules.unet.DualCondUNetXL + params: + use_checkpoint: True + in_channels: 4 + in_channels_fg: 4 + out_channels: 4 + model_channels: 320 + adm_in_channels: 512 + num_classes: sequential + attention_resolutions: [4, 2] + num_res_blocks: 2 + channel_mult: [1, 2, 4] + num_head_channels: 64 + use_spatial_transformer: true + use_linear_in_transformer: true + transformer_depth: [1, 2, 10] + context_dim: 2048 + map_module: false + warp_module: false + style_modulation: false + + bg_encoder_config: + target: refnet.modules.unet.ReferenceNet + params: + use_checkpoint: True + in_channels: 6 + model_channels: 320 + adm_in_channels: 1024 + num_classes: sequential + attention_resolutions: [ 4, 2 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4 ] + num_head_channels: 64 + use_spatial_transformer: true + use_linear_in_transformer: true + disable_cross_attentions: true + context_dim: 2048 + transformer_depth: [ 1, 2, 10 ] + + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + ddconfig: + double_z: true + z_channels: 4 + resolution: 512 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 4, 4] + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + + cond_stage_config: + target: refnet.modules.embedder.HFCLIPVisionModel + params: + arch: ViT-bigG-14 + + img_embedder_config: + target: refnet.modules.embedder.WDv14SwinTransformerV2 + + control_encoder_config: + target: refnet.modules.encoder.MultiScaleAttentionEncoder + params: + in_ch: 3 + model_channels: 320 + ch_mults: [1, 2, 4] + + proj_config: + target: refnet.modules.proj.ClusterConcat + # target: refnet.modules.proj.RecoveryClusterConcat + params: + input_dim: 1280 + c_dim: 1024 + output_dim: 2048 + token_length: 196 + dim_head: 128 + + scalar_embedder_config: + target: refnet.modules.embedder.TimestepEmbedding + params: + embed_dim: 256 + + lora_config: + 
lora_params: [ + { + label: background, + root_module: model.diffusion_model, + target_keys: [ attn2.to_q, attn2.to_k, attn2.to_v ], + r: 4, + } + ] \ No newline at end of file diff --git a/configs/scheduler_cfgs/ddim.yaml b/configs/scheduler_cfgs/ddim.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fb3d92393b590034df264ebf4e4555c3d5e4492a --- /dev/null +++ b/configs/scheduler_cfgs/ddim.yaml @@ -0,0 +1,10 @@ +beta_start: 0.00085 +beta_end: 0.012 +beta_schedule: "scaled_linear" +clip_sample: false +steps_offset: 1 + +### Zero-SNR params +#rescale_betas_zero_snr: True +#timestep_spacing: "trailing" +timestep_spacing: "leading" \ No newline at end of file diff --git a/configs/scheduler_cfgs/dpm.yaml b/configs/scheduler_cfgs/dpm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9410a2c84583c3310378ee0d9a32bda9b97c1f63 --- /dev/null +++ b/configs/scheduler_cfgs/dpm.yaml @@ -0,0 +1,8 @@ +beta_start: 0.00085 +beta_end: 0.012 +beta_schedule: "scaled_linear" +steps_offset: 1 + +### Zero-SNR params +#rescale_betas_zero_snr: True +timestep_spacing: "leading" \ No newline at end of file diff --git a/configs/scheduler_cfgs/dpm_sde.yaml b/configs/scheduler_cfgs/dpm_sde.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f383f8d6dba924c49a6a2f0717d3b93926db3071 --- /dev/null +++ b/configs/scheduler_cfgs/dpm_sde.yaml @@ -0,0 +1,9 @@ +beta_start: 0.00085 +beta_end: 0.012 +beta_schedule: "scaled_linear" +steps_offset: 1 + +### Zero-SNR params +#rescale_betas_zero_snr: True +timestep_spacing: "leading" +algorithm_type: sde-dpmsolver++ \ No newline at end of file diff --git a/configs/scheduler_cfgs/lms.yaml b/configs/scheduler_cfgs/lms.yaml new file mode 100644 index 0000000000000000000000000000000000000000..26950c26cca70bfb186b161aec77ab85118882b7 --- /dev/null +++ b/configs/scheduler_cfgs/lms.yaml @@ -0,0 +1,9 @@ +beta_start: 0.00085 +beta_end: 0.012 +beta_schedule: "scaled_linear" +#clip_sample: false 
+steps_offset: 1 + +### Zero-SNR params +#rescale_betas_zero_snr: True +timestep_spacing: "leading" \ No newline at end of file diff --git a/configs/scheduler_cfgs/pndm.yaml b/configs/scheduler_cfgs/pndm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c8342ca35f07d2f66ba9429f7283be86db0b82f7 --- /dev/null +++ b/configs/scheduler_cfgs/pndm.yaml @@ -0,0 +1,10 @@ +beta_start: 0.00085 +beta_end: 0.012 +beta_schedule: "scaled_linear" +#clip_sample: false +steps_offset: 1 + +### Zero-SNR params +#rescale_betas_zero_snr: True +#timestep_spacing: "trailing" +timestep_spacing: "leading" \ No newline at end of file diff --git a/k_diffusion/__init__.py b/k_diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..029a4a219b0259c183d177517a44fc0c0582dfe0 --- /dev/null +++ b/k_diffusion/__init__.py @@ -0,0 +1,8 @@ +from .sampling import * + + +def create_noise_sampler(x, sigmas, seed): + """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes""" + from k_diffusion.sampling import BrownianTreeNoiseSampler + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() + return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed) \ No newline at end of file diff --git a/k_diffusion/external.py b/k_diffusion/external.py new file mode 100644 index 0000000000000000000000000000000000000000..18b0fc6d317044baeb0a7bedc3da9d46189f538e --- /dev/null +++ b/k_diffusion/external.py @@ -0,0 +1,181 @@ +import math + +import torch +from torch import nn + +from . import sampling, utils + + +class VDenoiser(nn.Module): + """A v-diffusion-pytorch model wrapper for k-diffusion.""" + + def __init__(self, inner_model): + super().__init__() + self.inner_model = inner_model + self.sigma_data = 1. 
+ + def get_scalings(self, sigma): + c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 + c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 + return c_skip, c_out, c_in + + def sigma_to_t(self, sigma): + return sigma.atan() / math.pi * 2 + + def t_to_sigma(self, t): + return (t * math.pi / 2).tan() + + def loss(self, input, noise, sigma, **kwargs): + c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] + noised_input = input + noise * utils.append_dims(sigma, input.ndim) + model_output = self.inner_model(noised_input * c_in, self.sigma_to_t(sigma), **kwargs) + target = (input - c_skip * noised_input) / c_out + return (model_output - target).pow(2).flatten(1).mean(1) + + def forward(self, input, sigma, **kwargs): + c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] + return self.inner_model(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip + + +class DiscreteSchedule(nn.Module): + """A mapping between continuous noise levels (sigmas) and a list of discrete noise + levels.""" + + def __init__(self, sigmas, quantize): + super().__init__() + self.register_buffer('sigmas', sigmas) + self.register_buffer('log_sigmas', sigmas.log()) + self.quantize = quantize + + @property + def sigma_min(self): + return self.sigmas[0] + + @property + def sigma_max(self): + return self.sigmas[-1] + + def get_sigmas(self, n=None): + if n is None: + return sampling.append_zero(self.sigmas.flip(0)) + t_max = len(self.sigmas) - 1 + t = torch.linspace(t_max, 0, n, device=self.sigmas.device) + return sampling.append_zero(self.t_to_sigma(t)) + + def sigma_to_t(self, sigma, quantize=None): + quantize = self.quantize if quantize is None else quantize + log_sigma = sigma.log() + dists = log_sigma - self.log_sigmas[:, None] + if quantize: + return dists.abs().argmin(dim=0).view(sigma.shape) + low_idx 
= dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx] + w = (low - log_sigma) / (low - high) + w = w.clamp(0, 1) + t = (1 - w) * low_idx + w * high_idx + return t.view(sigma.shape) + + def t_to_sigma(self, t): + t = t.float() + low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac() + log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] + return log_sigma.exp() + + +class DiscreteEpsDDPMDenoiser(DiscreteSchedule): + """A wrapper for discrete schedule DDPM models that output eps (the predicted + noise).""" + + def __init__(self, model, alphas_cumprod, quantize): + super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize) + self.inner_model = model + self.sigma_data = 1. + + def get_scalings(self, sigma): + c_out = -sigma + c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 + return c_out, c_in + + def get_eps(self, *args, **kwargs): + return self.inner_model(*args, **kwargs) + + def loss(self, input, noise, sigma, **kwargs): + c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] + noised_input = input + noise * utils.append_dims(sigma, input.ndim) + eps = self.get_eps(noised_input * c_in, self.sigma_to_t(sigma), **kwargs) + return (eps - noise).pow(2).flatten(1).mean(1) + + def forward(self, input, sigma, **kwargs): + c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] + eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs) + return input + eps * c_out + + +class OpenAIDenoiser(DiscreteEpsDDPMDenoiser): + """A wrapper for OpenAI diffusion models.""" + + def __init__(self, model, diffusion, quantize=False, has_learned_sigmas=True, device='cpu'): + alphas_cumprod = torch.tensor(diffusion.alphas_cumprod, device=device, dtype=torch.float32) + super().__init__(model, alphas_cumprod, quantize=quantize) + 
self.has_learned_sigmas = has_learned_sigmas + + def get_eps(self, *args, **kwargs): + model_output = self.inner_model(*args, **kwargs) + if self.has_learned_sigmas: + return model_output.chunk(2, dim=1)[0] + return model_output + + +class CompVisDenoiser(DiscreteEpsDDPMDenoiser): + """A wrapper for CompVis diffusion models.""" + + def __init__(self, model, quantize=False, device='cpu'): + super().__init__(model, model.alphas_cumprod, quantize=quantize) + self.sigmas = self.sigmas.to(device) + self.log_sigmas = self.log_sigmas.to(device) + + def get_eps(self, *args, **kwargs): + return self.inner_model.apply_model(*args, **kwargs) + + +class DiscreteVDDPMDenoiser(DiscreteSchedule): + """A wrapper for discrete schedule DDPM models that output v.""" + + def __init__(self, model, alphas_cumprod, quantize): + super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize) + self.inner_model = model + self.sigma_data = 1. + + def get_scalings(self, sigma): + c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 + c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 + return c_skip, c_out, c_in + + def get_v(self, *args, **kwargs): + return self.inner_model(*args, **kwargs) + + def loss(self, input, noise, sigma, **kwargs): + c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] + noised_input = input + noise * utils.append_dims(sigma, input.ndim) + model_output = self.get_v(noised_input * c_in, self.sigma_to_t(sigma), **kwargs) + target = (input - c_skip * noised_input) / c_out + return (model_output - target).pow(2).flatten(1).mean(1) + + def forward(self, input, sigma, **kwargs): + c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] + return self.get_v(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip + + +class CompVisVDenoiser(DiscreteVDDPMDenoiser): + """A wrapper 
for CompVis diffusion models that output v."""
+
+    def __init__(self, model, quantize=False, device='cpu'):
+        super().__init__(model, model.alphas_cumprod, quantize=quantize)
+        self.sigmas = self.sigmas.to(device)
+        self.log_sigmas = self.log_sigmas.to(device)
+
+    def get_v(self, x, t, cond, **kwargs):
+        return self.inner_model.apply_model(x, t, cond)
diff --git a/k_diffusion/sampling.py b/k_diffusion/sampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..f842d3029e1ec65a29c7758de9785e6839cd86fc
--- /dev/null
+++ b/k_diffusion/sampling.py
@@ -0,0 +1,702 @@
+import math
+
+from scipy import integrate
+import torch
+from torch import nn
+from torchdiffeq import odeint
+import torchsde
+from tqdm.auto import trange, tqdm
+
+from . import utils
+
+
+def append_zero(x):
+    return torch.cat([x, x.new_zeros([1])])
+
+
+def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
+    """Constructs the noise schedule of Karras et al. (2022)."""
+    ramp = torch.linspace(0, 1, n).to(device)
+    min_inv_rho = sigma_min ** (1 / rho)
+    max_inv_rho = sigma_max ** (1 / rho)
+    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+    return append_zero(sigmas).to(device)
+
+
+def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'):
+    """Constructs an exponential noise schedule."""
+    sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp()
+    return append_zero(sigmas)
+
+
+def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'):
+    """Constructs a polynomial in log sigma noise schedule."""
+    ramp = torch.linspace(1, 0, n, device=device) ** rho
+    sigmas = torch.exp(ramp * (math.log(sigma_max) - math.log(sigma_min)) + math.log(sigma_min))
+    return append_zero(sigmas)
+
+
+def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
+    """Constructs a continuous VP noise schedule."""
+    t = torch.linspace(1, eps_s, n, device=device)
+    sigmas = torch.sqrt(torch.exp(beta_d
* t ** 2 / 2 + beta_min * t) - 1) + return append_zero(sigmas) + + +def to_d(x, sigma, denoised): + """Converts a denoiser output to a Karras ODE derivative.""" + return (x - denoised) / utils.append_dims(sigma, x.ndim) + + +def get_ancestral_step(sigma_from, sigma_to, eta=1.): + """Calculates the noise level (sigma_down) to step down to and the amount + of noise to add (sigma_up) when doing an ancestral sampling step.""" + if not eta: + return sigma_to, 0. + sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5) + sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 + return sigma_down, sigma_up + + +def default_noise_sampler(x): + return lambda sigma, sigma_next: torch.randn_like(x) + + +class BatchedBrownianTree: + """A wrapper around torchsde.BrownianTree that enables batches of entropy.""" + + def __init__(self, x, t0, t1, seed=None, **kwargs): + t0, t1, self.sign = self.sort(t0, t1) + w0 = kwargs.get('w0', torch.zeros_like(x)) + if seed is None: + seed = torch.randint(0, 2 ** 63 - 1, []).item() + self.batched = True + try: + assert len(seed) == x.shape[0] + w0 = w0[0] + except TypeError: + seed = [seed] + self.batched = False + self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed] + + @staticmethod + def sort(a, b): + return (a, b, 1) if a < b else (b, a, -1) + + def __call__(self, t0, t1): + t0, t1, sign = self.sort(t0, t1) + w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign) + return w if self.batched else w[0] + + +class BrownianTreeNoiseSampler: + """A noise sampler backed by a torchsde.BrownianTree. + + Args: + x (Tensor): The tensor whose shape, device and dtype to use to generate + random samples. + sigma_min (float): The low end of the valid interval. + sigma_max (float): The high end of the valid interval. + seed (int or List[int]): The random seed. 
If a list of seeds is + supplied instead of a single integer, then the noise sampler will + use one BrownianTree per batch item, each with its own seed. + transform (callable): A function that maps sigma to the sampler's + internal timestep. + """ + + def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x): + self.transform = transform + t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max)) + self.tree = BatchedBrownianTree(x, t0, t1, seed) + + def __call__(self, sigma, sigma_next): + t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next)) + return self.tree(t0, t1) / (t1 - t0).abs().sqrt() + + +@torch.no_grad() +def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): + """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).""" + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. 
+ eps = torch.randn_like(x) * s_noise + sigma_hat = sigmas[i] * (gamma + 1) + if gamma > 0: + x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 + denoised = model(x, sigma_hat * s_in, **extra_args) + d = to_d(x, sigma_hat, denoised) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) + dt = sigmas[i + 1] - sigma_hat + # Euler method + x = x + d * dt + return x + + +@torch.no_grad() +def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): + """Ancestral sampling with Euler method steps.""" + extra_args = {} if extra_args is None else extra_args + noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + d = to_d(x, sigmas[i], denoised) + # Euler method + dt = sigma_down - sigmas[i] + x = x + d * dt + if sigmas[i + 1] > 0: + x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up + return x + + +@torch.no_grad() +def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): + """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).""" + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. 
+ eps = torch.randn_like(x) * s_noise + sigma_hat = sigmas[i] * (gamma + 1) + if gamma > 0: + x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 + denoised = model(x, sigma_hat * s_in, **extra_args) + d = to_d(x, sigma_hat, denoised) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) + dt = sigmas[i + 1] - sigma_hat + if sigmas[i + 1] == 0: + # Euler method + x = x + d * dt + else: + # Heun's method + x_2 = x + d * dt + denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args) + d_2 = to_d(x_2, sigmas[i + 1], denoised_2) + d_prime = (d + d_2) / 2 + x = x + d_prime * dt + return x + + +@torch.no_grad() +def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): + """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).""" + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. 
+ eps = torch.randn_like(x) * s_noise + sigma_hat = sigmas[i] * (gamma + 1) + if gamma > 0: + x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 + denoised = model(x, sigma_hat * s_in, **extra_args) + d = to_d(x, sigma_hat, denoised) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) + if sigmas[i + 1] == 0: + # Euler method + dt = sigmas[i + 1] - sigma_hat + x = x + d * dt + else: + # DPM-Solver-2 + sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp() + dt_1 = sigma_mid - sigma_hat + dt_2 = sigmas[i + 1] - sigma_hat + x_2 = x + d * dt_1 + denoised_2 = model(x_2, sigma_mid * s_in, **extra_args) + d_2 = to_d(x_2, sigma_mid, denoised_2) + x = x + d_2 * dt_2 + return x + + +@torch.no_grad() +def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): + """Ancestral sampling with DPM-Solver second-order steps.""" + extra_args = {} if extra_args is None else extra_args + noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + d = to_d(x, sigmas[i], denoised) + if sigma_down == 0: + # Euler method + dt = sigma_down - sigmas[i] + x = x + d * dt + else: + # DPM-Solver-2 + sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp() + dt_1 = sigma_mid - sigmas[i] + dt_2 = sigma_down - sigmas[i] + x_2 = x + d * dt_1 + denoised_2 = model(x_2, sigma_mid * s_in, **extra_args) + d_2 = to_d(x_2, sigma_mid, denoised_2) + x = x + d_2 * dt_2 + x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up + return x + + +def linear_multistep_coeff(order, 
t, i, j): + if order - 1 > i: + raise ValueError(f'Order {order} too high for step {i}') + def fn(tau): + prod = 1. + for k in range(order): + if j == k: + continue + prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) + return prod + return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0] + + +@torch.no_grad() +def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4): + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + sigmas_cpu = sigmas.detach().cpu().numpy() + ds = [] + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + d = to_d(x, sigmas[i], denoised) + ds.append(d) + if len(ds) > order: + ds.pop(0) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + cur_order = min(i + 1, order) + coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)] + x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds))) + return x + + +@torch.no_grad() +def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4): + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + v = torch.randint_like(x, 2) * 2 - 1 + fevals = 0 + def ode_fn(sigma, x): + nonlocal fevals + with torch.enable_grad(): + x = x[0].detach().requires_grad_() + denoised = model(x, sigma * s_in, **extra_args) + d = to_d(x, sigma, denoised) + fevals += 1 + grad = torch.autograd.grad((d * v).sum(), x)[0] + d_ll = (v * grad).flatten(1).sum(1) + return d.detach(), d_ll + x_min = x, x.new_zeros([x.shape[0]]) + t = x.new_tensor([sigma_min, sigma_max]) + sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5') + latent, delta_ll = sol[0][-1], sol[1][-1] + ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1) + return ll_prior + delta_ll, {'fevals': fevals} + + +class 
PIDStepSizeController: + """A PID controller for ODE adaptive step size control.""" + def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8): + self.h = h + self.b1 = (pcoeff + icoeff + dcoeff) / order + self.b2 = -(pcoeff + 2 * dcoeff) / order + self.b3 = dcoeff / order + self.accept_safety = accept_safety + self.eps = eps + self.errs = [] + + def limiter(self, x): + return 1 + math.atan(x - 1) + + def propose_step(self, error): + inv_error = 1 / (float(error) + self.eps) + if not self.errs: + self.errs = [inv_error, inv_error, inv_error] + self.errs[0] = inv_error + factor = self.errs[0] ** self.b1 * self.errs[1] ** self.b2 * self.errs[2] ** self.b3 + factor = self.limiter(factor) + accept = factor >= self.accept_safety + if accept: + self.errs[2] = self.errs[1] + self.errs[1] = self.errs[0] + self.h *= factor + return accept + + +class DPMSolver(nn.Module): + """DPM-Solver. See https://arxiv.org/abs/2206.00927.""" + + def __init__(self, model, extra_args=None, eps_callback=None, info_callback=None): + super().__init__() + self.model = model + self.extra_args = {} if extra_args is None else extra_args + self.eps_callback = eps_callback + self.info_callback = info_callback + + def t(self, sigma): + return -sigma.log() + + def sigma(self, t): + return t.neg().exp() + + def eps(self, eps_cache, key, x, t, *args, **kwargs): + if key in eps_cache: + return eps_cache[key], eps_cache + sigma = self.sigma(t) * x.new_ones([x.shape[0]]) + eps = (x - self.model(x, sigma, *args, **self.extra_args, **kwargs)) / self.sigma(t) + if self.eps_callback is not None: + self.eps_callback() + return eps, {key: eps, **eps_cache} + + def dpm_solver_1_step(self, x, t, t_next, eps_cache=None): + eps_cache = {} if eps_cache is None else eps_cache + h = t_next - t + eps, eps_cache = self.eps(eps_cache, 'eps', x, t) + x_1 = x - self.sigma(t_next) * h.expm1() * eps + return x_1, eps_cache + + def dpm_solver_2_step(self, x, t, t_next, r1=1 / 2, eps_cache=None): + 
eps_cache = {} if eps_cache is None else eps_cache + h = t_next - t + eps, eps_cache = self.eps(eps_cache, 'eps', x, t) + s1 = t + r1 * h + u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps + eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1) + x_2 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / (2 * r1) * h.expm1() * (eps_r1 - eps) + return x_2, eps_cache + + def dpm_solver_3_step(self, x, t, t_next, r1=1 / 3, r2=2 / 3, eps_cache=None): + eps_cache = {} if eps_cache is None else eps_cache + h = t_next - t + eps, eps_cache = self.eps(eps_cache, 'eps', x, t) + s1 = t + r1 * h + s2 = t + r2 * h + u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps + eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1) + u2 = x - self.sigma(s2) * (r2 * h).expm1() * eps - self.sigma(s2) * (r2 / r1) * ((r2 * h).expm1() / (r2 * h) - 1) * (eps_r1 - eps) + eps_r2, eps_cache = self.eps(eps_cache, 'eps_r2', u2, s2) + x_3 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / r2 * (h.expm1() / h - 1) * (eps_r2 - eps) + return x_3, eps_cache + + def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None): + noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + if not t_end > t_start and eta: + raise ValueError('eta must be 0 for reverse sampling') + + m = math.floor(nfe / 3) + 1 + ts = torch.linspace(t_start, t_end, m + 1, device=x.device) + + if nfe % 3 == 0: + orders = [3] * (m - 2) + [2, 1] + else: + orders = [3] * (m - 1) + [nfe % 3] + + for i in range(len(orders)): + eps_cache = {} + t, t_next = ts[i], ts[i + 1] + if eta: + sd, su = get_ancestral_step(self.sigma(t), self.sigma(t_next), eta) + t_next_ = torch.minimum(t_end, self.t(sd)) + su = (self.sigma(t_next) ** 2 - self.sigma(t_next_) ** 2) ** 0.5 + else: + t_next_, su = t_next, 0. 
+ + eps, eps_cache = self.eps(eps_cache, 'eps', x, t) + denoised = x - self.sigma(t) * eps + if self.info_callback is not None: + self.info_callback({'x': x, 'i': i, 't': ts[i], 't_up': t, 'denoised': denoised}) + + if orders[i] == 1: + x, eps_cache = self.dpm_solver_1_step(x, t, t_next_, eps_cache=eps_cache) + elif orders[i] == 2: + x, eps_cache = self.dpm_solver_2_step(x, t, t_next_, eps_cache=eps_cache) + else: + x, eps_cache = self.dpm_solver_3_step(x, t, t_next_, eps_cache=eps_cache) + + x = x + su * s_noise * noise_sampler(self.sigma(t), self.sigma(t_next)) + + return x + + def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None): + noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + if order not in {2, 3}: + raise ValueError('order should be 2 or 3') + forward = t_end > t_start + if not forward and eta: + raise ValueError('eta must be 0 for reverse sampling') + h_init = abs(h_init) * (1 if forward else -1) + atol = torch.tensor(atol) + rtol = torch.tensor(rtol) + s = t_start + x_prev = x + accept = True + pid = PIDStepSizeController(h_init, pcoeff, icoeff, dcoeff, 1.5 if eta else order, accept_safety) + info = {'steps': 0, 'nfe': 0, 'n_accept': 0, 'n_reject': 0} + + while s < t_end - 1e-5 if forward else s > t_end + 1e-5: + eps_cache = {} + t = torch.minimum(t_end, s + pid.h) if forward else torch.maximum(t_end, s + pid.h) + if eta: + sd, su = get_ancestral_step(self.sigma(s), self.sigma(t), eta) + t_ = torch.minimum(t_end, self.t(sd)) + su = (self.sigma(t) ** 2 - self.sigma(t_) ** 2) ** 0.5 + else: + t_, su = t, 0. 
+ + eps, eps_cache = self.eps(eps_cache, 'eps', x, s) + denoised = x - self.sigma(s) * eps + + if order == 2: + x_low, eps_cache = self.dpm_solver_1_step(x, s, t_, eps_cache=eps_cache) + x_high, eps_cache = self.dpm_solver_2_step(x, s, t_, eps_cache=eps_cache) + else: + x_low, eps_cache = self.dpm_solver_2_step(x, s, t_, r1=1 / 3, eps_cache=eps_cache) + x_high, eps_cache = self.dpm_solver_3_step(x, s, t_, eps_cache=eps_cache) + delta = torch.maximum(atol, rtol * torch.maximum(x_low.abs(), x_prev.abs())) + error = torch.linalg.norm((x_low - x_high) / delta) / x.numel() ** 0.5 + accept = pid.propose_step(error) + if accept: + x_prev = x_low + x = x_high + su * s_noise * noise_sampler(self.sigma(s), self.sigma(t)) + s = t + info['n_accept'] += 1 + else: + info['n_reject'] += 1 + info['nfe'] += order + info['steps'] += 1 + + if self.info_callback is not None: + self.info_callback({'x': x, 'i': info['steps'] - 1, 't': s, 't_up': s, 'denoised': denoised, 'error': error, 'h': pid.h, **info}) + + return x, info + + +@torch.no_grad() +def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1., noise_sampler=None): + """DPM-Solver-Fast (fixed step size). 
See https://arxiv.org/abs/2206.00927."""
    if sigma_min <= 0 or sigma_max <= 0:
        raise ValueError('sigma_min and sigma_max must not be 0')
    with tqdm(total=n, disable=disable) as pbar:
        # eps_callback ticks the progress bar once per model evaluation.
        dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
        if callback is not None:
            # Translate the solver's t-space info dict into sigma-space for the caller.
            dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})
        # Solve from t(sigma_max) down to t(sigma_min) in n fixed steps.
        return dpm_solver.dpm_solver_fast(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), n, eta, s_noise, noise_sampler)


@torch.no_grad()
def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None, return_info=False):
    """DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927.

    The step size is controlled by a PID controller (pcoeff/icoeff/dcoeff) against
    the rtol/atol error estimate; set return_info=True to also get solver statistics
    (step counts, accept/reject counts, NFE).
    """
    if sigma_min <= 0 or sigma_max <= 0:
        raise ValueError('sigma_min and sigma_max must not be 0')
    # No total= here: the number of steps is not known in advance for the
    # adaptive solver, so the bar only counts model evaluations.
    with tqdm(disable=disable) as pbar:
        dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
        if callback is not None:
            dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})
        x, info = dpm_solver.dpm_solver_adaptive(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise, noise_sampler)
    if return_info:
        return x, info
    return x


@torch.no_grad()
def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with DPM-Solver++(2S) second-order steps."""
    extra_args = {} if extra_args is None else extra_args
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
    # Broadcast helper: one sigma value per batch element.
    s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp() + t_fn = lambda sigma: sigma.log().neg() + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + if sigma_down == 0: + # Euler method + d = to_d(x, sigmas[i], denoised) + dt = sigma_down - sigmas[i] + x = x + d * dt + else: + # DPM-Solver++(2S) + t, t_next = t_fn(sigmas[i]), t_fn(sigma_down) + r = 1 / 2 + h = t_next - t + s = t + r * h + x_2 = (sigma_fn(s) / sigma_fn(t)) * x - (-h * r).expm1() * denoised + denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args) + x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_2 + # Noise addition + if sigmas[i + 1] > 0: + x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up + return x + + +@torch.no_grad() +def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2): + """DPM-Solver++ (stochastic).""" + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() + noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + sigma_fn = lambda t: t.neg().exp() + t_fn = lambda sigma: sigma.log().neg() + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + if sigmas[i + 1] == 0: + # Euler method + d = to_d(x, sigmas[i], denoised) + dt = sigmas[i + 1] - sigmas[i] + x = x + d * dt + else: + # DPM-Solver++ + t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1]) + h = t_next - t + s = t + h * r + fac = 1 / (2 * r) + + # Step 1 + sd, 
su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta) + s_ = t_fn(sd) + x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised + x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su + denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args) + + # Step 2 + sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta) + t_next_ = t_fn(sd) + denoised_d = (1 - fac) * denoised + fac * denoised_2 + x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d + x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su + return x + + +@torch.no_grad() +def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None): + """DPM-Solver++(2M).""" + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + sigma_fn = lambda t: t.neg().exp() + t_fn = lambda sigma: sigma.log().neg() + old_denoised = None + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1]) + h = t_next - t + if old_denoised is None or sigmas[i + 1] == 0: + x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised + else: + h_last = t - t_fn(sigmas[i - 1]) + r = h_last / h + denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised + x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d + old_denoised = denoised + return x + + +@torch.no_grad() +def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): + """DPM-Solver++(2M) SDE.""" + + if solver_type not in {'heun', 'midpoint'}: + raise ValueError('solver_type must be \'heun\' or \'midpoint\'') + + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() + noise_sampler = 
BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + + old_denoised = None + h_last = None + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + if sigmas[i + 1] == 0: + # Denoising step + x = denoised + else: + # DPM-Solver++(2M) SDE + t, s = -sigmas[i].log(), -sigmas[i + 1].log() + h = s - t + eta_h = eta * h + + x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised + + if old_denoised is not None: + r = h_last / h + if solver_type == 'heun': + x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised) + elif solver_type == 'midpoint': + x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised) + + if eta: + x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise + + old_denoised = denoised + h_last = h + return x + + +@torch.no_grad() +def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): + """DPM-Solver++(3M) SDE.""" + + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() + noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + + denoised_1, denoised_2 = None, None + h_1, h_2 = None, None + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + if sigmas[i + 1] == 0: + # Denoising step + x = denoised + else: + 
t, s = -sigmas[i].log(), -sigmas[i + 1].log() + h = s - t + h_eta = h * (eta + 1) + + x = torch.exp(-h_eta) * x + (-h_eta).expm1().neg() * denoised + + if h_2 is not None: + r0 = h_1 / h + r1 = h_2 / h + d1_0 = (denoised - denoised_1) / r0 + d1_1 = (denoised_1 - denoised_2) / r1 + d1 = d1_0 + (d1_0 - d1_1) * r0 / (r0 + r1) + d2 = (d1_0 - d1_1) / (r0 + r1) + phi_2 = h_eta.neg().expm1() / h_eta + 1 + phi_3 = phi_2 / h_eta - 0.5 + x = x + phi_2 * d1 - phi_3 * d2 + elif h_1 is not None: + r = h_1 / h + d = (denoised - denoised_1) / r + phi_2 = h_eta.neg().expm1() / h_eta + 1 + x = x + phi_2 * d + + if eta: + x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise + + denoised_1, denoised_2 = denoised, denoised_1 + h_1, h_2 = h, h_1 + return x diff --git a/k_diffusion/utils.py b/k_diffusion/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..da857b4a8e41af93af9d010542e9246998d95d2b --- /dev/null +++ b/k_diffusion/utils.py @@ -0,0 +1,457 @@ +from contextlib import contextmanager +import hashlib +import math +from pathlib import Path +import shutil +import threading +import urllib +import warnings + +from PIL import Image +import safetensors +import torch +from torch import nn, optim +from torch.utils import data +from torchvision.transforms import functional as TF + + +def from_pil_image(x): + """Converts from a PIL image to a tensor.""" + x = TF.to_tensor(x) + if x.ndim == 2: + x = x[..., None] + return x * 2 - 1 + + +def to_pil_image(x): + """Converts from a tensor to a PIL image.""" + if x.ndim == 4: + assert x.shape[0] == 1 + x = x[0] + if x.shape[0] == 1: + x = x[0] + return TF.to_pil_image((x.clamp(-1, 1) + 1) / 2) + + +def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'): + """Apply passed in transforms for HuggingFace Datasets.""" + images = [transform(image.convert(mode)) for image in examples[image_key]] + return {image_key: images} + + +def append_dims(x, 
target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') + return x[(...,) + (None,) * dims_to_append] + + +def n_params(module): + """Returns the number of trainable parameters in a module.""" + return sum(p.numel() for p in module.parameters()) + + +def download_file(path, url, digest=None): + """Downloads a file if it does not exist, optionally checking its SHA-256 hash.""" + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + if not path.exists(): + with urllib.request.urlopen(url) as response, open(path, 'wb') as f: + shutil.copyfileobj(response, f) + if digest is not None: + file_digest = hashlib.sha256(open(path, 'rb').read()).hexdigest() + if digest != file_digest: + raise OSError(f'hash of {path} (url: {url}) failed to validate') + return path + + +@contextmanager +def train_mode(model, mode=True): + """A context manager that places a model into training mode and restores + the previous mode on exit.""" + modes = [module.training for module in model.modules()] + try: + yield model.train(mode) + finally: + for i, module in enumerate(model.modules()): + module.training = modes[i] + + +def eval_mode(model): + """A context manager that places a model into evaluation mode and restores + the previous mode on exit.""" + return train_mode(model, False) + + +@torch.no_grad() +def ema_update(model, averaged_model, decay): + """Incorporates updated model parameters into an exponential moving averaged + version of a model. 
It should be called after each optimizer step.""" + model_params = dict(model.named_parameters()) + averaged_params = dict(averaged_model.named_parameters()) + assert model_params.keys() == averaged_params.keys() + + for name, param in model_params.items(): + averaged_params[name].lerp_(param, 1 - decay) + + model_buffers = dict(model.named_buffers()) + averaged_buffers = dict(averaged_model.named_buffers()) + assert model_buffers.keys() == averaged_buffers.keys() + + for name, buf in model_buffers.items(): + averaged_buffers[name].copy_(buf) + + +class EMAWarmup: + """Implements an EMA warmup using an inverse decay schedule. + If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are + good values for models you plan to train for a million or more steps (reaches decay + factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models + you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at + 215.4k steps). + Args: + inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. + power (float): Exponential factor of EMA warmup. Default: 1. + min_value (float): The minimum EMA decay rate. Default: 0. + max_value (float): The maximum EMA decay rate. Default: 1. + start_at (int): The epoch to start averaging at. Default: 0. + last_epoch (int): The index of last epoch. Default: 0. + """ + + def __init__(self, inv_gamma=1., power=1., min_value=0., max_value=1., start_at=0, + last_epoch=0): + self.inv_gamma = inv_gamma + self.power = power + self.min_value = min_value + self.max_value = max_value + self.start_at = start_at + self.last_epoch = last_epoch + + def state_dict(self): + """Returns the state of the class as a :class:`dict`.""" + return dict(self.__dict__.items()) + + def load_state_dict(self, state_dict): + """Loads the class's state. + Args: + state_dict (dict): scaler state. Should be an object returned + from a call to :meth:`state_dict`. 
+ """ + self.__dict__.update(state_dict) + + def get_value(self): + """Gets the current EMA decay rate.""" + epoch = max(0, self.last_epoch - self.start_at) + value = 1 - (1 + epoch / self.inv_gamma) ** -self.power + return 0. if epoch < 0 else min(self.max_value, max(self.min_value, value)) + + def step(self): + """Updates the step count.""" + self.last_epoch += 1 + + +class InverseLR(optim.lr_scheduler._LRScheduler): + """Implements an inverse decay learning rate schedule with an optional exponential + warmup. When last_epoch=-1, sets initial lr as lr. + inv_gamma is the number of steps/epochs required for the learning rate to decay to + (1 / 2)**power of its original value. + Args: + optimizer (Optimizer): Wrapped optimizer. + inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1. + power (float): Exponential factor of learning rate decay. Default: 1. + warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable) + Default: 0. + min_lr (float): The minimum learning rate. Default: 0. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + """ + + def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., min_lr=0., + last_epoch=-1, verbose=False): + self.inv_gamma = inv_gamma + self.power = power + if not 0. 
<= warmup < 1: + raise ValueError('Invalid value for warmup') + self.warmup = warmup + self.min_lr = min_lr + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.") + + return self._get_closed_form_lr() + + def _get_closed_form_lr(self): + warmup = 1 - self.warmup ** (self.last_epoch + 1) + lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power + return [warmup * max(self.min_lr, base_lr * lr_mult) + for base_lr in self.base_lrs] + + +class ExponentialLR(optim.lr_scheduler._LRScheduler): + """Implements an exponential learning rate schedule with an optional exponential + warmup. When last_epoch=-1, sets initial lr as lr. Decays the learning rate + continuously by decay (default 0.5) every num_steps steps. + Args: + optimizer (Optimizer): Wrapped optimizer. + num_steps (float): The number of steps to decay the learning rate by decay in. + decay (float): The factor by which to decay the learning rate every num_steps + steps. Default: 0.5. + warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable) + Default: 0. + min_lr (float): The minimum learning rate. Default: 0. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + """ + + def __init__(self, optimizer, num_steps, decay=0.5, warmup=0., min_lr=0., + last_epoch=-1, verbose=False): + self.num_steps = num_steps + self.decay = decay + if not 0. 
<= warmup < 1: + raise ValueError('Invalid value for warmup') + self.warmup = warmup + self.min_lr = min_lr + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.") + + return self._get_closed_form_lr() + + def _get_closed_form_lr(self): + warmup = 1 - self.warmup ** (self.last_epoch + 1) + lr_mult = (self.decay ** (1 / self.num_steps)) ** self.last_epoch + return [warmup * max(self.min_lr, base_lr * lr_mult) + for base_lr in self.base_lrs] + + +class ConstantLRWithWarmup(optim.lr_scheduler._LRScheduler): + """Implements a constant learning rate schedule with an optional exponential + warmup. When last_epoch=-1, sets initial lr as lr. + Args: + optimizer (Optimizer): Wrapped optimizer. + warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable) + Default: 0. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + """ + + def __init__(self, optimizer, warmup=0., last_epoch=-1, verbose=False): + if not 0. 
<= warmup < 1: + raise ValueError('Invalid value for warmup') + self.warmup = warmup + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.") + + return self._get_closed_form_lr() + + def _get_closed_form_lr(self): + warmup = 1 - self.warmup ** (self.last_epoch + 1) + return [warmup * base_lr for base_lr in self.base_lrs] + + +def stratified_uniform(shape, group=0, groups=1, dtype=None, device=None): + """Draws stratified samples from a uniform distribution.""" + if groups <= 0: + raise ValueError(f"groups must be positive, got {groups}") + if group < 0 or group >= groups: + raise ValueError(f"group must be in [0, {groups})") + n = shape[-1] * groups + offsets = torch.arange(group, n, groups, dtype=dtype, device=device) + u = torch.rand(shape, dtype=dtype, device=device) + return (offsets + u) / n + + +stratified_settings = threading.local() + + +@contextmanager +def enable_stratified(group=0, groups=1, disable=False): + """A context manager that enables stratified sampling.""" + try: + stratified_settings.disable = disable + stratified_settings.group = group + stratified_settings.groups = groups + yield + finally: + del stratified_settings.disable + del stratified_settings.group + del stratified_settings.groups + + +@contextmanager +def enable_stratified_accelerate(accelerator, disable=False): + """A context manager that enables stratified sampling, distributing the strata across + all processes and gradient accumulation steps using settings from Hugging Face Accelerate.""" + try: + rank = accelerator.process_index + world_size = accelerator.num_processes + acc_steps = accelerator.gradient_state.num_steps + acc_step = accelerator.step % acc_steps + group = rank * acc_steps + acc_step + groups = world_size * acc_steps + with enable_stratified(group, groups, disable=disable): + yield + finally: + pass 
def stratified_with_settings(shape, dtype=None, device=None):
    """Draws stratified samples from a uniform distribution, using settings from a context
    manager.

    Falls back to plain torch.rand when no enable_stratified() context is active
    (or when it was entered with disable=True).
    """
    if not hasattr(stratified_settings, 'disable') or stratified_settings.disable:
        return torch.rand(shape, dtype=dtype, device=device)
    return stratified_uniform(
        shape, stratified_settings.group, stratified_settings.groups, dtype=dtype, device=device
    )


def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32):
    """Draws samples from a lognormal distribution."""
    # Squeeze u into (0, 1) so Normal.icdf never sees an endpoint (which would be +/-inf).
    u = stratified_with_settings(shape, device=device, dtype=dtype) * (1 - 2e-7) + 1e-7
    return torch.distributions.Normal(loc, scale).icdf(u).exp()


def rand_log_logistic(shape, loc=0., scale=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
    """Draws samples from an optionally truncated log-logistic distribution."""
    # Work in float64: the CDF endpoints can be very close together when the
    # truncation interval is narrow.
    min_value = torch.as_tensor(min_value, device=device, dtype=torch.float64)
    max_value = torch.as_tensor(max_value, device=device, dtype=torch.float64)
    min_cdf = min_value.log().sub(loc).div(scale).sigmoid()
    max_cdf = max_value.log().sub(loc).div(scale).sigmoid()
    # Inverse-CDF sampling restricted to [min_cdf, max_cdf].
    u = stratified_with_settings(shape, device=device, dtype=torch.float64) * (max_cdf - min_cdf) + min_cdf
    return u.logit().mul(scale).add(loc).exp().to(dtype)


def rand_log_uniform(shape, min_value, max_value, device='cpu', dtype=torch.float32):
    """Draws samples from a log-uniform distribution."""
    min_value = math.log(min_value)
    max_value = math.log(max_value)
    return (stratified_with_settings(shape, device=device, dtype=dtype) * (max_value - min_value) + min_value).exp()


def rand_v_diffusion(shape, sigma_data=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
    """Draws samples from a truncated v-diffusion training timestep distribution."""
    # CDF of sigma under the v-diffusion schedule is atan(sigma/sigma_data) * 2/pi.
    min_cdf = math.atan(min_value / sigma_data) * 2 / math.pi
    max_cdf = math.atan(max_value / sigma_data) * 2 / math.pi
    u = stratified_with_settings(shape, device=device, dtype=dtype) * (max_cdf - min_cdf) + min_cdf
    return torch.tan(u * math.pi / 2) * sigma_data


def rand_cosine_interpolated(shape, image_d, noise_d_low, noise_d_high, sigma_data=1., min_value=1e-3, max_value=1e3, device='cpu', dtype=torch.float32):
    """Draws samples from an interpolated cosine timestep distribution (from simple diffusion)."""

    def logsnr_schedule_cosine(t, logsnr_min, logsnr_max):
        # Standard cosine log-SNR schedule evaluated at t in [0, 1].
        t_min = math.atan(math.exp(-0.5 * logsnr_max))
        t_max = math.atan(math.exp(-0.5 * logsnr_min))
        return -2 * torch.log(torch.tan(t_min + t * (t_max - t_min)))

    def logsnr_schedule_cosine_shifted(t, image_d, noise_d, logsnr_min, logsnr_max):
        # Shift the schedule by the resolution ratio (simple diffusion trick).
        shift = 2 * math.log(noise_d / image_d)
        return logsnr_schedule_cosine(t, logsnr_min - shift, logsnr_max - shift) + shift

    def logsnr_schedule_cosine_interpolated(t, image_d, noise_d_low, noise_d_high, logsnr_min, logsnr_max):
        # Linearly interpolate between the low- and high-noise shifted schedules.
        logsnr_low = logsnr_schedule_cosine_shifted(t, image_d, noise_d_low, logsnr_min, logsnr_max)
        logsnr_high = logsnr_schedule_cosine_shifted(t, image_d, noise_d_high, logsnr_min, logsnr_max)
        return torch.lerp(logsnr_low, logsnr_high, t)

    logsnr_min = -2 * math.log(min_value / sigma_data)
    logsnr_max = -2 * math.log(max_value / sigma_data)
    u = stratified_with_settings(shape, device=device, dtype=dtype)
    logsnr = logsnr_schedule_cosine_interpolated(u, image_d, noise_d_low, noise_d_high, logsnr_min, logsnr_max)
    # Convert log-SNR back to sigma: sigma = sigma_data * exp(-logsnr / 2).
    return torch.exp(-logsnr / 2) * sigma_data


def rand_split_log_normal(shape, loc, scale_1, scale_2, device='cpu', dtype=torch.float32):
    """Draws samples from a split lognormal distribution."""
    n = torch.randn(shape, device=device, dtype=dtype).abs()
    u = torch.rand(shape, device=device, dtype=dtype)
    # Left branch uses scale_1, right branch scale_2; mass split by scale ratio.
    n_left = n * -scale_1 + loc
    n_right = n * scale_2 + loc
    ratio = scale_1 / (scale_1 + scale_2)
    return torch.where(u < ratio, n_left, n_right).exp()


class FolderOfImages(data.Dataset):
    """Recursively
finds all images in a directory. It does not support
    classes/targets."""

    IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'}

    def __init__(self, root, transform=None):
        super().__init__()
        self.root = Path(root)
        self.transform = nn.Identity() if transform is None else transform
        # Sorted for a deterministic index -> path mapping across runs.
        self.paths = sorted(path for path in self.root.rglob('*') if path.suffix.lower() in self.IMG_EXTENSIONS)

    def __repr__(self):
        return f'FolderOfImages(root="{self.root}", len: {len(self)})'

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, key):
        path = self.paths[key]
        with open(path, 'rb') as f:
            image = Image.open(f).convert('RGB')
        image = self.transform(image)
        # Trailing comma: items are returned as a 1-tuple.
        return image,


class CSVLogger:
    """Minimal CSV logger.

    Appends to the file if it already exists (the header row is written only
    when creating a new file; existing columns are not validated against
    `columns`). Values are joined with bare commas -- no quoting or escaping,
    so fields must not themselves contain commas.
    NOTE(review): the file handle is never closed; it lives for the lifetime
    of the logger object.
    """

    def __init__(self, filename, columns):
        self.filename = Path(filename)
        self.columns = columns
        if self.filename.exists():
            self.file = open(self.filename, 'a')
        else:
            self.file = open(self.filename, 'w')
            self.write(*self.columns)

    def write(self, *args):
        # flush=True so rows survive a crash mid-run.
        print(*args, sep=',', file=self.file, flush=True)


@contextmanager
def tf32_mode(cudnn=None, matmul=None):
    """A context manager that sets whether TF32 is allowed on cuDNN or matmul.

    Passing None for either flag leaves that backend setting untouched.
    """
    cudnn_old = torch.backends.cudnn.allow_tf32
    matmul_old = torch.backends.cuda.matmul.allow_tf32
    try:
        if cudnn is not None:
            torch.backends.cudnn.allow_tf32 = cudnn
        if matmul is not None:
            torch.backends.cuda.matmul.allow_tf32 = matmul
        yield
    finally:
        # Restore only what was changed.
        if cudnn is not None:
            torch.backends.cudnn.allow_tf32 = cudnn_old
        if matmul is not None:
            torch.backends.cuda.matmul.allow_tf32 = matmul_old


def get_safetensors_metadata(path):
    """Retrieves the metadata from a safetensors file."""
    return safetensors.safe_open(path, "pt").metadata()


def ema_update_dict(values, updates, decay):
    """Exponential-moving-average update of `values` with `updates`, in place.

    Keys not yet present in `values` are inserted as-is; existing keys are
    blended as decay * old + (1 - decay) * new. Returns the mutated dict.
    """
    for k, v in updates.items():
        if k not in values:
            values[k] = v
        else:
            values[k] *= decay
            values[k] += (1 - decay) * v
    return values
diff --git a/ldm/modules/diffusionmodules/__init__.py b/ldm/modules/diffusionmodules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py new file mode 100644 index 0000000000000000000000000000000000000000..8b1dd07cee6ca2252f61b72e6425d60880366cb2 --- /dev/null +++ b/ldm/modules/diffusionmodules/model.py @@ -0,0 +1,488 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange +from typing import Optional, Any + +from refnet.util import checkpoint_wrapper, default + +try: + import xformers + import xformers.ops + + XFORMERS_IS_AVAILBLE = True + attn_processor = xformers.ops.memory_efficient_attention +except: + XFORMERS_IS_AVAILBLE = False + attn_processor = F.scaled_dot_product_attention + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". 
+ """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = 
torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + @checkpoint_wrapper + def forward(self, x, temb=None): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * 
(int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + +class MemoryEfficientAttnBlock(nn.Module): + """ + Uses xformers efficient implementation, + see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + Note: this is a single-head self-attention operation + """ + # + def __init__(self, in_channels, head_dim=None): + super().__init__() + self.in_channels = in_channels + self.head_dim = default(head_dim, in_channels) + self.heads = in_channels // self.head_dim + # if self.head_dim > 256: + # self.attn_processor = F.scaled_dot_product_attention + # else: + self.attn_processor = attn_processor + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.attention_op: Optional[Any] = None + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + B, C, H, W = q.shape + q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v)) + + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(B, -1, self.heads, C) + .permute(0, 2, 1, 3) + .reshape(B * self.heads, -1, C) + .contiguous(), + (q, k, v), + ) + out = self.attn_processor(q, k, v) + + out = ( + out.unsqueeze(0) + .reshape(B, 1, out.shape[1], C) + .permute(0, 2, 1, 3) + .reshape(B, out.shape[1], C) + ) + out = 
def make_attn(in_channels, **kwargs):
    """Factory for attention blocks.

    Keyword arguments (e.g. ``attn_type`` passed by Encoder/Decoder) are
    accepted for call-site compatibility but ignored: this build always
    constructs the memory-efficient single-head attention block.
    """
    return MemoryEfficientAttnBlock(in_channels)
class Decoder(nn.Module):
    """VAE-style decoder: maps a latent ``z`` back to image space.

    Mirrors ``Encoder`` in reverse: conv_in -> middle (ResnetBlock / attn /
    ResnetBlock) -> per-resolution upsampling stacks -> GroupNorm +
    nonlinearity + conv_out. ``temb`` is always None here (temb_ch == 0).
    """

    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
                 attn_type="vanilla", checkpoint=True, **ignorekwargs):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0  # no timestep embedding in the autoencoder
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,)+tuple(ch_mult)  # NOTE(review): unused here; kept for parity with Encoder
        block_in = ch*ch_mult[self.num_resolutions-1]
        curr_res = resolution // 2**(self.num_resolutions-1)
        self.z_shape = (1,z_channels,curr_res,curr_res)

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels,
                                       block_in,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # upsampling (built top-down but stored bottom-up via insert(0, ...))
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            # one extra ResnetBlock per level compared with the Encoder
            for i_block in range(self.num_res_blocks+1):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)
        self.checkpoint = checkpoint

    @checkpoint_wrapper
    def forward(self, z):
        """Decode latent ``z`` (B, z_channels, H', W') to an image tensor."""
        #assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape  # remembered for debugging/inspection

        # timestep embedding (unused: temb_ch == 0)
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling, lowest resolution first
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks+1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end: optionally return pre-activation features
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h
class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian posterior parameterized by a (mean, logvar) tensor.

    ``parameters`` has shape (B, 2*C, ...): the channel dim is split in half
    into mean and log-variance. ``deterministic=True`` collapses the
    distribution to a Dirac at the mean (zero variance).
    """

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        # Clamp logvar so exp() stays finite and gradients remain sane.
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean)

    def sample(self):
        """Draw a reparameterized sample: mean + std * eps, eps ~ N(0, I).

        Fix: randn_like keeps the noise on the parameters' device AND dtype;
        the previous ``torch.randn(shape).to(device)`` allocated CPU float32
        noise and forced an extra host->device copy every call.
        """
        return self.mean + self.std * torch.randn_like(self.mean)

    def kl(self, other=None):
        """KL(self || other); against the standard normal when other is None.

        Returns a per-sample tensor reduced over dims 1..3. In deterministic
        mode the zero tensor is created on the parameters' device (the old
        CPU ``torch.Tensor([0.])`` could break device-sensitive callers).
        """
        if self.deterministic:
            return torch.zeros(1, device=self.parameters.device)
        if other is None:
            return 0.5 * torch.sum(torch.pow(self.mean, 2)
                                   + self.var - 1.0 - self.logvar,
                                   dim=[1, 2, 3])
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var - 1.0 - self.logvar + other.logvar,
            dim=[1, 2, 3])

    def nll(self, sample, dims=(1, 2, 3)):
        """Negative log-likelihood of ``sample``, reduced over ``dims``.

        ``dims`` default is now an immutable tuple (was a mutable list
        default); torch.sum accepts either, so callers are unaffected.
        """
        if self.deterministic:
            return torch.zeros(1, device=self.parameters.device)
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        """Most likely value of the distribution (the mean)."""
        return self.mean
normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/preprocessor/__init__.py b/preprocessor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c9f74f96424781101eb41633498af68911c3d877 --- /dev/null +++ b/preprocessor/__init__.py @@ -0,0 +1,124 @@ +import os + +import torch.hub +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms.functional as tf +import functools + +model_path = "preprocessor/weights" +os.environ["HF_HOME"] = model_path +torch.hub.set_dir(model_path) + +from torch.hub import download_url_to_file +from transformers import AutoModelForImageSegmentation +from .anime2sketch import UnetGenerator +from .manga_line_extractor import res_skip +from .sketchKeras import SketchKeras +from .sk_model import LineartDetector +from .anime_segment import ISNetDIS +from refnet.util import load_weights + + +class NoneMaskExtractor(nn.Module): + def __init__(self): + super().__init__() + self.identity = nn.Identity() + + def proceed(self, x: torch.Tensor, th=None, tw=None, dilate=False, *args, **kwargs): + b, c, h, w = x.shape + return 
def create_model(model_name="lineart"):
    """Build and return a preprocessing model in eval mode.

    Supported names: "none" (no-op mask extractor), the BiRefNet hub models
    in ``BiRefNet_dict`` (rmbg-v2, BiRefNet, BiRefNet_HR), and the local
    line-extraction / segmentation models in ``remote_model_dict``. Remote
    weights are downloaded into ``model_path`` on first use.
    """
    if model_name == "none":
        return NoneMaskExtractor().eval()

    if model_name in BiRefNet_dict:
        # BiRefNet-style matting models are fetched from the HF hub.
        model = AutoModelForImageSegmentation.from_pretrained(
            BiRefNet_dict[model_name][0],
            trust_remote_code = True,
            cache_dir = model_path,
            device_map = None,
            low_cpu_mem_usage = False,
        )
        model.eval()
        model.image_size = BiRefNet_dict[model_name][1]
        # Bind the shared proceed() implementation to this instance.
        model.proceed = rmbg_proceed.__get__(model, model.__class__)
        return model

    assert model_name in remote_model_dict
    remote_path = remote_model_dict[model_name]
    basename = os.path.basename(remote_path)
    ckpt_path = os.path.join(model_path, basename)

    # exist_ok avoids a race when several workers initialize concurrently.
    os.makedirs(model_path, exist_ok=True)

    if not os.path.exists(ckpt_path):
        # Download to a temp name and rename, so a partial download never
        # masquerades as a valid checkpoint. Fix: derive the temp path from
        # model_path (it was hard-coded to "preprocessor/weights/weights.tmp"
        # and would break if model_path ever changed).
        cache_path = os.path.join(model_path, "weights.tmp")
        download_url_to_file(remote_path, dst=cache_path)
        os.rename(cache_path, ckpt_path)

    if model_name == "lineart":
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
        model = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False)
    elif model_name == "lineart_denoise":
        model = res_skip()
    elif model_name == "lineart_keras":
        model = SketchKeras()
    elif model_name == "lineart_sk":
        model = LineartDetector()
    elif model_name in ("ISNet", "ISNet-sketch"):
        model = ISNetDIS()
    else:
        # Unreachable: guarded by the assert above.
        return None

    ckpt = load_weights(ckpt_path)
    # Strip DataParallel's "module." prefix from checkpoint keys.
    for key in list(ckpt.keys()):
        if 'module.' in key:
            ckpt[key.replace('module.', '')] = ckpt.pop(key)
    model.load_state_dict(ckpt)
    return model.eval()
+ """ + super(UnetGenerator, self).__init__() + # construct unet structure + unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer + for _ in range(num_downs - 5): # add intermediate layers with ngf * 8 filters + unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout) + # gradually reduce the number of filters from ngf * 8 to ngf + unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer + + def forward(self, input): + """Standard forward""" + return self.model(input) + + def proceed(self, img): + sketch = self(to_tensor(img)) + return -sketch + + +class UnetSkipConnectionBlock(nn.Module): + """Defines the Unet submodule with skip connection. + X -------------------identity---------------------- + |-- downsampling -- |submodule| -- upsampling --| + """ + + def __init__(self, outer_nc, inner_nc, input_nc=None, + submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): + """Construct a Unet submodule with skip connections. 
+ Parameters: + outer_nc (int) -- the number of filters in the outer conv layer + inner_nc (int) -- the number of filters in the inner conv layer + input_nc (int) -- the number of channels in input images/features + submodule (UnetSkipConnectionBlock) -- previously defined submodules + outermost (bool) -- if this module is the outermost module + innermost (bool) -- if this module is the innermost module + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers. + """ + super(UnetSkipConnectionBlock, self).__init__() + self.outermost = outermost + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + if input_nc is None: + input_nc = outer_nc + downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, + stride=2, padding=1, bias=use_bias) + downrelu = nn.LeakyReLU(0.2, True) + downnorm = norm_layer(inner_nc) + uprelu = nn.ReLU(True) + upnorm = norm_layer(outer_nc) + + if outermost: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, + kernel_size=4, stride=2, + padding=1) + down = [downconv] + up = [uprelu, upconv, nn.Tanh()] + model = down + [submodule] + up + elif innermost: + upconv = nn.ConvTranspose2d(inner_nc, outer_nc, + kernel_size=4, stride=2, + padding=1, bias=use_bias) + down = [downrelu, downconv] + up = [uprelu, upconv, upnorm] + model = down + up + else: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, + kernel_size=4, stride=2, + padding=1, bias=use_bias) + down = [downrelu, downconv, downnorm] + up = [uprelu, upconv, upnorm] + + if use_dropout: + model = down + [submodule] + up + [nn.Dropout(0.5)] + else: + model = down + [submodule] + up + + self.model = nn.Sequential(*model) + + def forward(self, x): + if self.outermost: + return self.model(x).clamp(-1, 1) + else: # add skip connections + return torch.cat([x, self.model(x)], 1) \ No newline at end of file diff --git a/preprocessor/anime_segment.py b/preprocessor/anime_segment.py 
### RSU-7 ###
class RSU7(nn.Module):
    """Residual U-block with 7 levels (U^2-Net building block).

    Encoder path of REBNCONV + max-pool stages, a dilated bottom conv, and a
    decoder path concatenating skip features; an outer residual adds the
    input projection back at the end.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
        # img_size is unused; kept in the signature for caller compatibility.
        super(RSU7, self).__init__()

        self.in_ch = in_ch
        self.mid_ch = mid_ch
        self.out_ch = out_ch

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)  ## 1 -> 1/2

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)  # dilated bottom

        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        # (fix) removed an unused `b, c, h, w = x.shape` unpack.
        hx = x
        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)

        hx6 = self.rebnconv6(hx)

        hx7 = self.rebnconv7(hx6)

        # decoder: concat skip + upsample to the next skip's spatial size
        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
        hx6dup = _upsample_like(hx6d, hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin  # outer residual
### RSU-5 ###
class RSU5(nn.Module):
    # Residual U-block with 5 levels (U^2-Net building block): encoder path
    # of REBNCONV + max-pool stages, a dilated bottom conv, and a decoder
    # path that concatenates skip features, plus an outer residual.
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU5, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)  # input projection for the residual

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)  # dilated bottom, no further pooling

        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)

        hx5 = self.rebnconv5(hx4)

        # decoder: concat skip + upsample to match the next skip's size
        hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin  # outer residual
class ISNetDIS(nn.Module):
    """ISNet segmentation network (from SkyTNT's anime-segmentation).

    U^2-Net-style encoder/decoder built from RSU blocks; forward returns only
    the sigmoid of the finest side output (the other side heads are kept for
    checkpoint compatibility but unused).
    """

    def __init__(self, in_ch=3, out_ch=1):
        super(ISNetDIS, self).__init__()

        self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)  # halves resolution
        self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage1 = RSU7(64, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        # decoder
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        # side output heads; only side1 is used in forward()
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

    def forward(self, x):
        """Return a sigmoid mask upsampled to the input's spatial size."""
        hx = x

        hxin = self.conv_in(hx)
        # NOTE(review): pool_in's result is overwritten below without being
        # used (stage1 consumes hxin directly) — appears vestigial; confirm
        # against the upstream repo before removing.
        hx = self.pool_in(hxin)

        # stage 1
        hx1 = self.stage1(hxin)
        hx = self.pool12(hx1)

        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        # stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        # -------------------- decoder --------------------
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # side output (only the finest one is produced)
        d1 = self.side1(hx1d)
        d1 = _upsample_like(d1, x)

        # d2..d6 side outputs and the fused d0 from the original multi-head
        # variant are intentionally disabled here.
        return torch.sigmoid(d1)

    def proceed(self, x: torch.Tensor, th=None, tw=None, s=1024, dilate=False, crop=True, *args, **kwargs):
        """Run segmentation on an image batch and return a (B,1,th,tw) mask.

        x is expected in [-1, 1] (it is remapped via 1 - (x+1)/2 below).
        crop=True letterboxes the aspect-preserving resize onto an s x s
        canvas; crop=False stretches to s x s directly.
        """
        b, c, h, w = x.shape

        if crop:
            th, tw = default(th, h), default(tw, w)
            scale = s / max(h, w)
            h, w = int(h * scale), int(w * scale)

            # -1 canvas == white after the inversion below (background fill)
            canvas = -torch.ones((b, c, s, s), dtype=x.dtype, device=x.device)
            ph, pw = (s - h) // 2, (s - w) // 2
            # NOTE(review): scale_factor resize may round to a size that
            # differs by 1px from int(h*scale)/int(w*scale) — confirm the
            # paste below cannot mismatch for unusual inputs.
            x = F.interpolate(x, scale_factor=scale, mode="bicubic")

            canvas[:, :, ph: ph+h, pw: pw+w] = x

            canvas = 1 - (canvas + 1.) / 2.
            mask = self(canvas)[:, :, ph: ph+h, pw: pw+w]

        else:
            # NOTE(review): this branch needs th/tw supplied by the caller;
            # they are not defaulted here.
            x = F.interpolate(x, size=(s, s), mode="bicubic")
            mask = self(x)

        mask = F.interpolate(mask, (th, tw), mode="bicubic").clamp(0, 1)

        if dilate:
            # grow the mask by ~10px via max-pooling
            mask = F.max_pool2d(mask, kernel_size=21, stride=1, padding=10)
            # mask = mask_expansion(mask, 32, 20)
        return mask
nn.Conv2d(in_filters, nb_filters, (fw, fh), stride=subsample, padding=(fw//2, fh//2)), + nn.Upsample(scale_factor=2, mode='nearest') + ) + + def forward(self, x): + return self.model(x) + + + +class _shortcut(nn.Module): + def __init__(self, in_filters, nb_filters, subsample=1): + super(_shortcut, self).__init__() + self.process = False + self.model = None + if in_filters != nb_filters or subsample != 1: + self.process = True + self.model = nn.Sequential( + nn.Conv2d(in_filters, nb_filters, (1, 1), stride=subsample) + ) + + def forward(self, x, y): + #print(x.size(), y.size(), self.process) + if self.process: + y0 = self.model(x) + #print("merge+", torch.max(y0+y), torch.min(y0+y),torch.mean(y0+y), torch.std(y0+y), y0.shape) + return y0 + y + else: + #print("merge", torch.max(x+y), torch.min(x+y),torch.mean(x+y), torch.std(x+y), y.shape) + return x + y + +class _u_shortcut(nn.Module): + def __init__(self, in_filters, nb_filters, subsample): + super(_u_shortcut, self).__init__() + self.process = False + self.model = None + if in_filters != nb_filters: + self.process = True + self.model = nn.Sequential( + nn.Conv2d(in_filters, nb_filters, (1, 1), stride=subsample, padding_mode='zeros'), + nn.Upsample(scale_factor=2, mode='nearest') + ) + + def forward(self, x, y): + if self.process: + return self.model(x) + y + else: + return x + y + + +class basic_block(nn.Module): + def __init__(self, in_filters, nb_filters, init_subsample=1): + super(basic_block, self).__init__() + self.conv1 = _bn_relu_conv(in_filters, nb_filters, 3, 3, subsample=init_subsample) + self.residual = _bn_relu_conv(nb_filters, nb_filters, 3, 3) + self.shortcut = _shortcut(in_filters, nb_filters, subsample=init_subsample) + + def forward(self, x): + x1 = self.conv1(x) + x2 = self.residual(x1) + return self.shortcut(x, x2) + +class _u_basic_block(nn.Module): + def __init__(self, in_filters, nb_filters, init_subsample=1): + super(_u_basic_block, self).__init__() + self.conv1 = _u_bn_relu_conv(in_filters, 
class _u_basic_block(nn.Module):
    """Upsampling residual block: _u_bn_relu_conv -> _bn_relu_conv with an
    upsampled shortcut."""

    def __init__(self, in_filters, nb_filters, init_subsample=1):
        super().__init__()
        self.conv1 = _u_bn_relu_conv(in_filters, nb_filters, 3, 3, subsample=init_subsample)
        self.residual = _bn_relu_conv(nb_filters, nb_filters, 3, 3)
        self.shortcut = _u_shortcut(in_filters, nb_filters, subsample=init_subsample)

    def forward(self, x):
        y = self.residual(self.conv1(x))
        return self.shortcut(x, y)


class _residual_block(nn.Module):
    """Stack of ``repetitions`` basic_blocks; the last block downsamples
    (stride 2) unless ``is_first_layer``.

    Cleanup: removed dead ``l = None`` assignments and collapsed the
    duplicated i==0 / i>0 construction branches.
    """

    def __init__(self, in_filters, nb_filters, repetitions, is_first_layer=False):
        super().__init__()
        layers = []
        for i in range(repetitions):
            stride = 2 if (i == repetitions - 1 and not is_first_layer) else 1
            layers.append(basic_block(
                in_filters=in_filters if i == 0 else nb_filters,
                nb_filters=nb_filters,
                init_subsample=stride,
            ))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class _upsampling_residual_block(nn.Module):
    """One _u_basic_block (2x upsample) followed by ``repetitions - 1``
    plain basic_blocks."""

    def __init__(self, in_filters, nb_filters, repetitions):
        super().__init__()
        layers = []
        for i in range(repetitions):
            if i == 0:
                layers.append(_u_basic_block(in_filters=in_filters, nb_filters=nb_filters))
            else:
                layers.append(basic_block(in_filters=nb_filters, nb_filters=nb_filters))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class res_skip(nn.Module):
    """Residual U-Net line extractor: 5 encoder stages (24..384 channels),
    4 upsampling decoder stages with additive skip connections, and a 1x1
    output head."""

    def __init__(self):
        super().__init__()
        # encoder
        self.block0 = _residual_block(in_filters=1, nb_filters=24, repetitions=2, is_first_layer=True)
        self.block1 = _residual_block(in_filters=24, nb_filters=48, repetitions=3)
        self.block2 = _residual_block(in_filters=48, nb_filters=96, repetitions=5)
        self.block3 = _residual_block(in_filters=96, nb_filters=192, repetitions=7)
        self.block4 = _residual_block(in_filters=192, nb_filters=384, repetitions=12)
        # decoder with additive skips (res1..res4 are identity shortcuts)
        self.block5 = _upsampling_residual_block(in_filters=384, nb_filters=192, repetitions=7)
        self.res1 = _shortcut(in_filters=192, nb_filters=192)
        self.block6 = _upsampling_residual_block(in_filters=192, nb_filters=96, repetitions=5)
        self.res2 = _shortcut(in_filters=96, nb_filters=96)
        self.block7 = _upsampling_residual_block(in_filters=96, nb_filters=48, repetitions=3)
        self.res3 = _shortcut(in_filters=48, nb_filters=48)
        self.block8 = _upsampling_residual_block(in_filters=48, nb_filters=24, repetitions=2)
        self.res4 = _shortcut(in_filters=24, nb_filters=24)
        # output head
        self.block9 = _residual_block(in_filters=24, nb_filters=16, repetitions=2, is_first_layer=True)
        self.conv15 = _bn_relu_conv(in_filters=16, nb_filters=1, fh=1, fw=1, subsample=1)

    def forward(self, x):
        x0 = self.block0(x)
        x1 = self.block1(x0)
        x2 = self.block2(x1)
        x3 = self.block3(x2)
        x4 = self.block4(x3)

        h = self.res1(x3, self.block5(x4))
        h = self.res2(x2, self.block6(h))
        h = self.res3(x1, self.block7(h))
        h = self.res4(x0, self.block8(h))

        return self.conv15(self.block9(h))

    def proceed(self, sketch):
        """Run on a PIL image and return a [-1, 1] line map (lines dark).
        NOTE(review): assumes a CUDA device is available — confirm callers.
        """
        sketch = transforms.ToTensor()(sketch).unsqueeze(0)[:, 0] * 255
        sketch = sketch.unsqueeze(1).cuda()
        sketch = self(sketch) / 127.5 - 1
        return -sketch.clamp(-1, 1)
# Instance norm throughout (re-stated here so this unit is self-contained).
norm_layer = nn.InstanceNorm2d


class ResidualBlock(nn.Module):
    """Two reflection-padded 3x3 convs with instance norm; identity skip."""

    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features),
        )

    def forward(self, x):
        return x + self.conv_block(x)


class Generator(nn.Module):
    """ResNet-style image-to-image generator: 7x7 stem, two stride-2
    downsampling stages (64 -> 256 channels), ``n_residual_blocks`` residual
    blocks, two transposed-conv upsampling stages, 7x7 head (optionally
    sigmoid-bounded)."""

    def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
        super(Generator, self).__init__()

        # Stem
        self.model0 = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 64, 7),
            norm_layer(64),
            nn.ReLU(inplace=True),
        )

        # Downsampling: 64 -> 128 -> 256
        down = []
        ch = 64
        for _ in range(2):
            down += [
                nn.Conv2d(ch, ch * 2, 3, stride=2, padding=1),
                norm_layer(ch * 2),
                nn.ReLU(inplace=True),
            ]
            ch *= 2
        self.model1 = nn.Sequential(*down)

        # Residual trunk at 256 channels
        self.model2 = nn.Sequential(*[ResidualBlock(ch) for _ in range(n_residual_blocks)])

        # Upsampling: 256 -> 128 -> 64
        up = []
        for _ in range(2):
            up += [
                nn.ConvTranspose2d(ch, ch // 2, 3, stride=2, padding=1, output_padding=1),
                norm_layer(ch // 2),
                nn.ReLU(inplace=True),
            ]
            ch //= 2
        self.model3 = nn.Sequential(*up)

        # Head
        head = [nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7)]
        if sigmoid:
            head.append(nn.Sigmoid())
        self.model4 = nn.Sequential(*head)

    def forward(self, x, cond=None):
        # ``cond`` is accepted for interface compatibility but unused.
        h = self.model0(x)
        h = self.model1(h)
        h = self.model2(h)
        h = self.model3(h)
        return self.model4(h)


class LineartDetector(nn.Module):
    """Thin wrapper holding a 3-block Generator (RGB -> 1-channel lineart)."""

    def __init__(self):
        super().__init__()
        self.model = Generator(3, 1, 3)

    def load_state_dict(self, sd):
        # Delegate so checkpoints saved for the bare Generator load directly.
        self.model.load_state_dict(sd)

    def proceed(self, sketch):
        """Run on a PIL image; returns the negated lineart map.
        NOTE(review): assumes CUDA is available — confirm callers."""
        sketch = tf.pil_to_tensor(sketch).unsqueeze(0).cuda().float()
        return -self.model(sketch)


def postprocess(pred, thresh=0.18):
    """Collapse multi-channel prediction by per-pixel max, zero values below
    ``thresh``, then map [0, 1] -> [-1, 1]. Mutates ``pred`` in place."""
    assert 0.0 <= thresh <= 1.0

    pred = torch.amax(pred, 0)
    pred[pred < thresh] = 0
    pred -= 0.5
    pred *= 2
    return pred
def _conv_bn_relu(cin, cout, k, stride):
    """One reflection-padded conv stage: pad(1) -> Conv(k, stride) -> BN -> ReLU."""
    return [
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(cin, cout, kernel_size=k, stride=stride),
        nn.BatchNorm2d(cout, eps=1e-3, momentum=0),
        nn.ReLU(),
    ]


def _up_block(cin, cmid, cout):
    """Decoder stage: bicubic 2x upsample, 4x4 conv (asymmetric pad), 3x3 conv."""
    return nn.Sequential(
        nn.Upsample(scale_factor=2, mode="bicubic"),
        nn.ReflectionPad2d((1, 2, 1, 2)),
        nn.Conv2d(cin, cmid, kernel_size=4, stride=1),
        nn.BatchNorm2d(cmid, eps=1e-3, momentum=0),
        nn.ReLU(),
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(cmid, cout, kernel_size=3, stride=1),
        nn.BatchNorm2d(cout, eps=1e-3, momentum=0),
        nn.ReLU(),
    )


class SketchKeras(nn.Module):
    """U-Net sketch extractor (port of sketchKeras): six down blocks
    (1 -> 512 channels), four up blocks with channel-concatenated skips,
    and a final 3x3 conv back to one channel."""

    def __init__(self):
        super(SketchKeras, self).__init__()

        self.downblock_1 = nn.Sequential(*_conv_bn_relu(1, 32, 3, 1))
        self.downblock_2 = nn.Sequential(*_conv_bn_relu(32, 64, 4, 2), *_conv_bn_relu(64, 64, 3, 1))
        self.downblock_3 = nn.Sequential(*_conv_bn_relu(64, 128, 4, 2), *_conv_bn_relu(128, 128, 3, 1))
        self.downblock_4 = nn.Sequential(*_conv_bn_relu(128, 256, 4, 2), *_conv_bn_relu(256, 256, 3, 1))
        self.downblock_5 = nn.Sequential(*_conv_bn_relu(256, 512, 4, 2))
        self.downblock_6 = nn.Sequential(*_conv_bn_relu(512, 512, 3, 1))

        self.upblock_1 = _up_block(1024, 512, 256)
        self.upblock_2 = _up_block(512, 256, 128)
        self.upblock_3 = _up_block(256, 128, 64)
        self.upblock_4 = _up_block(128, 64, 32)

        self.last_pad = nn.ReflectionPad2d((1, 1, 1, 1))
        self.last_conv = nn.Conv2d(64, 1, kernel_size=3, stride=1)

    def forward(self, x):
        # Encoder: keep every stage's output for the skip connections.
        feats = []
        h = x
        for down in (self.downblock_1, self.downblock_2, self.downblock_3,
                     self.downblock_4, self.downblock_5, self.downblock_6):
            h = down(h)
            feats.append(h)

        # Decoder: each up block consumes cat(skip, current) on channels.
        h = feats[5]
        for up, skip in zip(
            (self.upblock_1, self.upblock_2, self.upblock_3, self.upblock_4),
            (feats[4], feats[3], feats[2], feats[1]),
        ):
            h = up(torch.cat((skip, h), dim=1))

        return self.last_conv(self.last_pad(torch.cat((feats[0], h), dim=1)))

    def proceed(self, img):
        """High-pass the image (subtract a Gaussian blur), normalize by its
        max, and run the network per input channel.
        NOTE(review): assumes CUDA, and that the highpass max is > 0 — a
        perfectly flat image would divide by zero; confirm callers."""
        img = np.array(img)
        blurred = cv2.GaussianBlur(img, (0, 0), 3)
        img = img.astype(int) - blurred.astype(int)
        img = img.astype(np.float32) / 127.5
        img /= np.max(img)
        # (H, W, C) -> (C, 1, H, W): each channel becomes a batch entry.
        img = torch.tensor(img).unsqueeze(0).permute(3, 0, 1, 2).cuda()
        img = self(img)
        img = postprocess(img, thresh=0.1).unsqueeze(1)
        return img
def disabled_train(self, mode=True):
    """Drop-in replacement for ``nn.Module.train`` that ignores mode switches,
    pinning a module in whatever train/eval state it already has."""
    return self


def uniform_on_device(r1, r2, shape, device):
    """Sample ``shape`` values uniformly between ``r2`` and ``r1`` on ``device``
    (computed as ``(r1 - r2) * U[0, 1) + r2``)."""
    return (r1 - r2) * torch.rand(*shape, device=device) + r2


def rescale_zero_terminal_snr(betas):
    """Rescale a beta schedule to have zero terminal SNR.

    Implements Algorithm 1 of arXiv:2305.08891 ("Common Diffusion Noise
    Schedules and Sample Steps are Flawed"): shift sqrt(alpha_bar) so its last
    entry is exactly 0, rescale so its first entry is unchanged, then convert
    back to betas.

    Args:
        betas (torch.FloatTensor): schedule to rescale.
    Returns:
        torch.FloatTensor: rescaled betas with zero terminal SNR.
    """
    sqrt_abar = torch.cumprod(1.0 - betas, dim=0).sqrt()

    first = sqrt_abar[0].clone()
    last = sqrt_abar[-1].clone()

    sqrt_abar -= last                        # terminal value -> 0
    sqrt_abar *= first / (first - last)      # restore the initial value

    abar = sqrt_abar ** 2                    # back to alpha_bar
    alphas = torch.cat([abar[0:1], abar[1:] / abar[:-1]])  # un-cumprod
    return 1 - alphas
+ self.channels = channels + self.model = DiffusionWrapper(unet_config) + count_params(self.model, verbose=True) + self.v_posterior = v_posterior + self.half_precision_dtype = torch.bfloat16 if half_precision_dtype == "bfloat16" else torch.float16 + self.register_schedule(beta_schedule=beta_schedule, timesteps=timesteps, + linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s, zero_snr=zero_snr) + + + def register_schedule(self, beta_schedule="scaled_linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, zero_snr=False): + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s, zero_snr=zero_snr) + + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( + 1. 
- alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer('posterior_mean_coef1', to_torch( + betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef2', to_torch( + (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) + + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + + def predict_start_from_z_and_v(self, x_t, t, v): + # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. 
- alphas_cumprod))) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v + ) + + def add_noise(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise).to(x_start.dtype) + + def get_v(self, x, noise, t): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x + ) + + def normalize_timesteps(self, timesteps): + return timesteps + + +class LatentDiffusion(DDPM): + """main class""" + + def __init__( + self, + first_stage_config, + cond_stage_config, + scale_factor = 1.0, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.scale_factor = scale_factor + self.first_stage_model, self.cond_stage_model = map( + lambda t: instantiate_from_config(t).eval().requires_grad_(False), + (first_stage_config, cond_stage_config) + ) + + @torch.no_grad() + def get_first_stage_encoding(self, x): + encoder_posterior = self.first_stage_model.encode(x) + z = encoder_posterior.sample() * self.scale_factor + return z.to(self.dtype).detach() + + @torch.no_grad() + def decode_first_stage(self, z): + z = 1. 
class LatentDiffusion(DDPM):
    """DDPM operating in the latent space of a frozen autoencoder.

    Both the first-stage (VAE) and conditioning models are instantiated from
    configs, set to eval mode, and frozen (requires_grad=False).
    """

    def __init__(
            self,
            first_stage_config,
            cond_stage_config,
            scale_factor=1.0,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.scale_factor = scale_factor
        self.first_stage_model = instantiate_from_config(first_stage_config).eval().requires_grad_(False)
        self.cond_stage_model = instantiate_from_config(cond_stage_config).eval().requires_grad_(False)

    @torch.no_grad()
    def get_first_stage_encoding(self, x):
        """Encode an image and return a scaled latent sample, detached."""
        posterior = self.first_stage_model.encode(x)
        z = posterior.sample() * self.scale_factor
        return z.to(self.dtype).detach()

    @torch.no_grad()
    def decode_first_stage(self, z):
        """Undo the latent scaling and decode back to image space, detached."""
        z = 1. / self.scale_factor * z
        return self.first_stage_model.decode(z.to(self.first_stage_model.dtype)).detach()

    def apply_model(self, x_noisy, t, cond):
        return self.model(x_noisy, t, **cond)

    def get_learned_embedding(self, c, *args, **kwargs):
        """Embed a reference image: returns (wd_emb, wd_logits, clip_emb),
        each detached (wd outputs may be None)."""
        wd_emb, wd_logits = self.img_embedder.encode(c, **kwargs)
        wd_emb = wd_emb.detach() if exists(wd_emb) else None
        wd_logits = wd_logits.detach() if exists(wd_logits) else None
        clip_emb = self.cond_stage_model.encode(c, **kwargs).detach()
        return wd_emb, wd_logits, clip_emb


class DiffusionWrapper(nn.Module):
    """Thin wrapper around the UNet that concatenates list-valued
    conditionings before the forward call."""

    def __init__(self, diff_model_config):
        super().__init__()
        self.diffusion_model = instantiate_from_config(diff_model_config)

    def forward(self, x, t, **cond):
        # Multi-source conditionings arrive as lists of tensors; concatenate
        # along dim 1 (channels/tokens). ``cond`` is the local kwargs dict,
        # so mutating it does not affect the caller.
        for key in ("context", "y", "concat"):
            if key in cond:
                cond[key] = torch.cat(cond[key], 1)
        return self.diffusion_model(x, t, **cond)
class AttentionPool2d(nn.Module):
    """CLIP-style attention pooling: prepend the spatial mean as a query
    token, add a learned positional embedding, run one QKV attention pass,
    and return the pooled (first) token.
    Adapted from https://github.com/openai/CLIP/blob/main/clip/model.py
    """

    def __init__(
        self,
        spacial_dim: int,
        embed_dim: int,
        num_heads_channels: int,
        output_dim: int = None,
    ):
        super().__init__()
        self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        b, c, *_spatial = x.shape
        flat = x.reshape(b, c, -1)                                        # NC(HW)
        flat = th.cat([flat.mean(dim=-1, keepdim=True), flat], dim=-1)    # NC(HW+1)
        flat = flat + self.positional_embedding[None, :, :].to(flat.dtype)
        out = self.c_proj(self.attention(self.qkv_proj(flat)))
        return out[:, :, 0]


class TimestepBlock(nn.Module):
    """Any module whose forward() also takes a timestep embedding."""

    @abstractmethod
    def forward(self, x, emb):
        """Apply the module to `x` given `emb` timestep embeddings."""


class Upsample(nn.Module):
    """2x nearest-neighbor upsampling with an optional 3x3 conv.

    :param channels: input (and default output) channel count.
    :param use_conv: apply a conv after upsampling.
    :param dims: 1/2/3-D; for 3-D only the inner two dims are upsampled.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            # Keep the depth dimension; double only the spatial dims.
            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(x) if self.use_conv else x


class TransposedUpsample(nn.Module):
    """Learned 2x upsampling without padding (stride-2 transposed conv)."""

    def __init__(self, channels, out_channels=None, ks=5):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.up = nn.ConvTranspose2d(self.channels, self.out_channels, kernel_size=ks, stride=2)

    def forward(self, x):
        return self.up(x)
+ """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. 
+ """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + @checkpoint_wrapper + def forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = 
th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + @checkpoint_wrapper + def forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. 
def count_flops_attn(model, _x, y):
    """`thop` custom-op counter for QKV attention.

    Both matmuls (scores and value combination) cost the same, hence the
    factor of two. Use as:
        thop.profile(model, inputs=..., custom_ops={QKVAttention: QKVAttention.count_flops})
    """
    b, c, *spatial = y[0].shape
    num_spatial = int(np.prod(spatial))
    matmul_ops = 2 * b * (num_spatial ** 2) * c
    model.total_ops += th.DoubleTensor([matmul_ops])


class QKVAttentionLegacy(nn.Module):
    """QKV attention matching the legacy head layout: heads are split
    before the q/k/v channels."""

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """Apply QKV attention.

        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        # Scale q and k separately by ch^-1/4: more stable in fp16 than
        # dividing the product afterwards.
        scale = 1 / math.sqrt(math.sqrt(ch))
        logits = th.einsum("bct,bcs->bts", q * scale, k * scale)
        probs = th.softmax(logits.float(), dim=-1).type(logits.dtype)
        return th.einsum("bts,bcs->bct", probs, v).reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)


class QKVAttention(nn.Module):
    """QKV attention with the newer layout: q/k/v are split before heads."""

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """Apply QKV attention.

        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.chunk(3, dim=1)
        # Same ch^-1/4 pre-scaling trick as the legacy variant.
        scale = 1 / math.sqrt(math.sqrt(ch))
        logits = th.einsum(
            "bct,bcs->bts",
            (q * scale).view(bs * self.n_heads, ch, length),
            (k * scale).view(bs * self.n_heads, ch, length),
        )
        probs = th.softmax(logits.float(), dim=-1).type(logits.dtype)
        out = th.einsum("bts,bcs->bct", probs, v.reshape(bs * self.n_heads, ch, length))
        return out.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)


class Timestep(nn.Module):
    """Sinusoidal timestep embedding of dimension ``dim``."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        return timestep_embedding(t, self.dim)
+ + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, zero_snr=False): + if schedule == "linear": + betas = ( + torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + ) + elif schedule == "scaled_linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "squaredcos_cap_v2": # used for karlo prior + # return early + return betas_for_alpha_bar( + n_timestep, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule 
'{schedule}' unknown.") + + if zero_snr: + betas = rescale_zero_terminal_snr(betas) + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. 
def extract_into_tensor(a, t, x_shape):
    """Gather a[t] per batch element and reshape to broadcast against x_shape."""
    batch = t.shape[0]
    gathered = a.gather(-1, t)
    # append singleton dims so the result broadcasts over x's non-batch dims
    return gathered.reshape(batch, *((1,) * (len(x_shape) - 1)))


class CheckpointFunction(torch.autograd.Function):
    """Gradient checkpointing: skip storing activations in forward and
    re-run `run_function` under grad during backward."""

    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        # remember the autocast state so the recomputation matches forward
        ctx.gpu_autocast_kwargs = {
            "enabled": torch.is_autocast_enabled(),
            "dtype": torch.get_autocast_gpu_dtype(),
            "cache_enabled": torch.is_autocast_cache_enabled(),
        }
        with torch.no_grad():
            outputs = ctx.run_function(*ctx.input_tensors)
        return outputs

    @staticmethod
    def backward(ctx, *output_grads):
        ctx.input_tensors = [t.detach().requires_grad_(True) for t in ctx.input_tensors]
        with torch.enable_grad(), \
                torch.amp.autocast("cuda", **ctx.gpu_autocast_kwargs):
            # view_as produces shallow copies: if the first op in run_function
            # mutates its input in place, it cannot touch the detach()'d
            # tensors' storage directly (which autograd forbids).
            shallow_copies = [t.view_as(t) for t in ctx.input_tensors]
            outputs = ctx.run_function(*shallow_copies)
            input_grads = torch.autograd.grad(
                outputs,
                ctx.input_tensors + ctx.input_params,
                output_grads,
                allow_unused=True,
            )
        del ctx.input_tensors
        del ctx.input_params
        del outputs
        # first two slots correspond to run_function and length
        return (None, None) + input_grads
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if repeat_only:
        # broadcast the raw timestep across all channels
        return timesteps.unsqueeze(-1).repeat(1, dim)
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(half, dtype=torch.float32) / half
    ).to(device=timesteps.device)
    angles = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
    if dim % 2:
        # odd dim: pad with one zero channel
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for param in module.parameters():
        param.detach().zero_()
    return module


def scale_module(module, scale):
    """
    Scale the parameters of a module and return it.
    """
    for param in module.parameters():
        param.detach().mul_(scale)
    return module


def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    non_batch_dims = list(range(1, tensor.ndim))
    return tensor.mean(dim=non_batch_dims)


def normalization(channels):
    """
    Make a standard normalization layer.
    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(32, channels)


# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


class GroupNorm32(nn.GroupNorm):
    """GroupNorm computed in the parameter dtype, cast back to the input dtype."""

    def forward(self, x):
        in_dtype = x.dtype
        return super().forward(x.to(self.weight.dtype)).type(in_dtype)
+ """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() diff --git a/refnet/modules/__init__.py b/refnet/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6bc7b110c367b580747e9b1d0748ced2dd2dcfb9 --- /dev/null +++ b/refnet/modules/__init__.py @@ -0,0 +1,34 @@ +from collections import namedtuple + + +def wd_v14_swin2_tagger_config(): + CustomConfig = namedtuple('CustomConfig', [ + 'architecture', 'num_classes', 'num_features', 'global_pool', 'model_args', 'pretrained_cfg' + ]) + + custom_config = CustomConfig( + architecture="swinv2_base_window8_256", + num_classes=9083, + num_features=1024, + global_pool="avg", + model_args={ + "act_layer": "gelu", + "img_size": 448, + "window_size": 14 + }, + pretrained_cfg={ + "custom_load": False, + "input_size": [3, 448, 448], + "fixed_input_size": False, + "interpolation": "bicubic", + "crop_pct": 1.0, + "crop_mode": "center", + "mean": [0.5, 0.5, 0.5], + "std": [0.5, 0.5, 0.5], + "num_classes": 9083, + "pool_size": None, + "first_conv": None, + "classifier": None + } + ) + return custom_config \ No newline at end of file diff --git 
def create_masked_attention_bias(
    mask: torch.Tensor,
    threshold: float,
    num_heads: int,
    context_len: int
):
    """Build an additive attention bias that routes foreground queries to the
    first half of the context and background queries to the second half.
    Returns a (b, num_heads, seq_len, padded_context_len) bias tensor.
    NOTE(review): the last dim is padded up to a multiple of 8 (memory-efficient
    attention alignment) but is not sliced back to context_len — confirm the
    downstream kernel tolerates the extra zero-bias columns.
    """
    b, seq_len, _ = mask.shape
    half_len = context_len // 2
    # round up to a multiple of 8 (no-op when already aligned)
    padded_context_len = ((context_len + 7) // 8) * 8

    fg_bias = torch.zeros(b, seq_len, padded_context_len, device=mask.device, dtype=mask.dtype)
    bg_bias = torch.zeros(b, seq_len, padded_context_len, device=mask.device, dtype=mask.dtype)
    fg_bias[:, :, half_len:] = -float('inf')
    bg_bias[:, :, :half_len] = -float('inf')

    attn_bias = torch.where(mask > threshold, fg_bias, bg_bias)
    return attn_bias.unsqueeze(1).repeat_interleave(num_heads, dim=1)


class Identity(nn.Module):
    """Pass-through module that ignores any extra positional/keyword args."""

    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x


class RotaryPositionalEmbeddings(nn.Module):
    """2D (axial) rotary positional embeddings: the head dim is split between
    row and column rotations over a (h, w) grid."""

    def __init__(self, dim, max_seq_len=1024, theta=10000.0):
        super().__init__()
        assert dim % 2 == 0, "Dimension must be divisible by 2"
        half = dim // 2
        self.max_seq_len = max_seq_len
        inv_freq = 1.0 / torch.pow(theta, torch.arange(0, half, 2).to(torch.float64).div(half))
        angles = torch.outer(torch.arange(self.max_seq_len), inv_freq)
        table = torch.polar(torch.ones_like(angles), angles)  # unit complex rotations
        self.register_buffer("freq_h", table, persistent=False)
        self.register_buffer("freq_w", table, persistent=False)

    def forward(self, x, grid_size):
        """Rotate x (b, n, heads, dim) according to its (h, w) grid position."""
        bs, seq_len, heads = x.shape[:3]
        h, w = grid_size

        xc = torch.view_as_complex(
            x.float().reshape(bs, seq_len, heads, -1, 2)
        )
        rot = torch.cat([
            self.freq_h[:h].view(1, h, 1, -1).expand(bs, h, w, -1),
            self.freq_w[:w].view(1, 1, w, -1).expand(bs, h, w, -1)
        ], dim=-1).reshape(bs, seq_len, 1, -1)

        rotated = torch.view_as_real(xc * rot).flatten(3)
        return rotated.type_as(x)
class MemoryEfficientAttention(nn.Module):
    # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
    """Self/cross attention with optional RoPE, qk RMS-norm and a
    mask-driven split (foreground/background) cross-attention path."""

    def __init__(
            self,
            query_dim,
            context_dim = None,
            heads = None,
            dim_head = 64,
            dropout = 0.0,
            log = False,
            causal = False,
            rope = False,
            max_seq_len = 1024,
            qk_norm = False,
            **kwargs
    ):
        super().__init__()
        if log:
            print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
                  f"{heads} heads.")

        heads = heads or query_dim // dim_head
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.heads = heads
        self.dim_head = dim_head

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

        self.q_norm = RMSNorm(inner_dim) if qk_norm else Identity()
        self.k_norm = RMSNorm(inner_dim) if qk_norm else Identity()
        self.rope = RotaryPositionalEmbeddings(dim_head, max_seq_len=max_seq_len) if rope else Identity()
        self.attn_ops = causal_ops if causal else {}

        # default setting for split cross-attention
        self.bg_scale = 1.
        self.fg_scale = 1.
        self.merge_scale = 0.
        self.mask_threshold = 0.05

    @checkpoint_wrapper
    def forward(
        self,
        x,
        context=None,
        mask=None,
        scale=1.,
        scale_factor=None,
        grid_size=None,
        **kwargs,
    ):
        """Plain (self or cross) attention; with a mask, dispatch to the
        split foreground/background path instead."""
        context = default(context, x)

        if exists(mask):
            out = self.masked_forward(x, context, mask, scale, scale_factor)
        else:
            out = self.attn_forward(self.to_q(x), self.to_k(context), self.to_v(context), scale, grid_size)

        return self.to_out(out)

    def attn_forward(self, q, k, v, scale=1., grid_size=None, mask=None):
        """Normalize + RoPE q/k, run the attention kernel, rescale the output."""
        q, k = (
            self.rope(rearrange(t, "b n (h c) -> b n h c", h=self.heads), grid_size)
            for t in (self.q_norm(q), self.k_norm(k))
        )
        v = rearrange(v, "b n (h c) -> b n h c", h=self.heads)
        out = attn_processor(q, k, v, attn_mask=mask, **self.attn_ops) * scale
        return rearrange(out, "b n h c -> b n (h c)")

    def masked_forward(self, x, context, mask, scale=1., scale_factor=None):
        """Split cross-attention: route each query to the foreground or
        background half of the context according to the (resized) mask."""
        def project(x_in, ctx):
            return self.to_q(x_in), self.to_k(ctx), self.to_v(ctx)

        assert exists(scale_factor), "Scale factor must be assigned before masked attention"
        # bring the spatial mask to the token resolution of x
        mask = rearrange(
            F.interpolate(mask, scale_factor=scale_factor, mode="bicubic"),
            "b c h w -> b (h w) c"
        ).contiguous()

        if self.merge_scale > 0:
            # merging variant: two full attention passes, blended per region
            c1, c2 = context.chunk(2, dim=1)

            # Background region cross-attention
            bg_out = self.attn_forward(*project(x, c2), scale) * self.bg_scale

            # Foreground region cross-attention
            fg_out = self.attn_forward(*project(x, c1), scale) * self.fg_scale

            fg_out = fg_out * (1 - self.merge_scale) + bg_out * self.merge_scale
            return torch.where(mask < self.mask_threshold, bg_out, fg_out)

        # single pass with an additive bias that hides the wrong half
        attn_mask = create_masked_attention_bias(
            mask, self.mask_threshold, self.heads, context.size(1)
        )
        q, k, v = project(x, context)
        return self.attn_forward(q, k, v, mask=attn_mask) * scale
class MultiModalAttention(MemoryEfficientAttention):
    """Cross-attention over two context modalities with separate K/V
    projections; the two per-stream outputs are summed with per-stream scales."""

    def __init__(self, query_dim, context_dim_2, heads=8, dim_head=64, qk_norm=False, *args, **kwargs):
        super().__init__(query_dim, heads=heads, dim_head=dim_head, qk_norm=qk_norm, *args, **kwargs)
        inner_dim = dim_head * heads
        self.to_k_2 = nn.Linear(context_dim_2, inner_dim, bias=False)
        self.to_v_2 = nn.Linear(context_dim_2, inner_dim, bias=False)
        self.k2_norm = RMSNorm(inner_dim) if qk_norm else Identity()

    def forward(self, x, context=None, mask=None, scale=1., grid_size=None):
        """
        :param x: query tokens (b, n, query_dim).
        :param context: stacked modalities, expected (b, 2, n, c).
        :param scale: scalar or (scale_1, scale_2) pair weighting the streams.
        """
        if not isinstance(scale, list) and not isinstance(scale, tuple):
            scale = (scale, scale)
        assert len(context.shape) == 4, "Multi-modal attention requires different context inputs to be (b, m, n c)"
        # chunk() keeps the singleton modality axis, which would break the
        # (b, n, (h c)) rearranges below — collapse it here.
        # NOTE(review): assumes exactly two modalities stacked on dim 1 — confirm callers.
        context, context2 = (c.squeeze(1) for c in context.chunk(2, dim=1))

        q = self.to_q(x)
        k = self.to_k(context)
        v = self.to_v(context)
        k2 = self.to_k_2(context2)
        # BUG FIX: v2 was previously computed with to_k_2, so the second
        # stream's values reused the key projection and to_v_2 was dead weight.
        v2 = self.to_v_2(context2)

        q, k, k2 = map(
            lambda t: self.rope(rearrange(t, "b n (h c) -> b n h c", h=self.heads), grid_size),
            (self.q_norm(q), self.k_norm(k), self.k2_norm(k2))
        )
        v, v2 = map(lambda t: rearrange(t, "b n (h c) -> b n h c", h=self.heads), (v, v2))

        out = (attn_processor(q, k, v, **self.attn_ops) * scale[0] +
               attn_processor(q, k2, v2, **self.attn_ops) * scale[1])

        if exists(mask):
            raise NotImplementedError
        out = rearrange(out, "b n h c -> b n (h c)")
        return self.to_out(out)


class MultiScaleCausalAttention(MemoryEfficientAttention):
    """Causal attention over concatenated multi-scale token groups: the i-th
    group attends to itself and every earlier group (plus one leading global
    token carried by the first group)."""

    def forward(
        self,
        x,
        context=None,
        mask=None,
        scale=1.,
        scale_factor=None,
        grid_size=None,
        token_lens=None
    ):
        context = default(context, x)
        q = self.to_q(x)
        k = self.to_k(context)
        v = self.to_v(context)
        out = self.attn_forward(q, k, v, scale, grid_size=grid_size, token_lens=token_lens)
        return self.to_out(out)

    def attn_forward(self, q, k, v, scale=1., grid_size=None, token_lens=None):
        """
        :param grid_size: per-group (h, w) grids for RoPE.
        :param token_lens: per-group token counts (first group has one extra
            leading token that is excluded from RoPE).
        """
        q, k, v = map(
            lambda t: rearrange(t, "b n (h c) -> b n h c", h=self.heads),
            (self.q_norm(q), self.k_norm(k), v)
        )

        attn_output = []
        prev_idx = 0
        for idx, (grid, length) in enumerate(zip(grid_size, token_lens)):
            end_idx = prev_idx + length + (idx == 0)
            rope_prev_idx = prev_idx + (idx == 0)
            rope_slice = slice(rope_prev_idx, end_idx)

            # rotate only this group's grid tokens, in place
            q[:, rope_slice] = self.rope(q[:, rope_slice], grid)
            k[:, rope_slice] = self.rope(k[:, rope_slice], grid)
            qs = q[:, prev_idx: end_idx]
            # causal over groups: keys/values include everything up to here
            ks, vs = map(lambda t: t[:, :end_idx], (k, v))

            # clone(): the slices are views of tensors still mutated in-place
            # on later iterations
            attn_output.append(attn_processor(qs.clone(), ks.clone(), vs.clone()) * scale)
            prev_idx = end_idx
        # (a large commented-out flash-attn varlen variant was removed here as dead code)
        return rearrange(torch.cat(attn_output, 1), "b n h c -> b n (h c)")
import torch
import torch.nn.functional as F

ATTN_PRECISION = torch.float16

try:
    import flash_attn_interface
    FLASH_ATTN_3_AVAILABLE = True
    FLASH_ATTN_AVAILABLE = False
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False
    try:
        import flash_attn
        FLASH_ATTN_AVAILABLE = True
    except ModuleNotFoundError:
        FLASH_ATTN_AVAILABLE = False

try:
    import xformers.ops
    XFORMERS_IS_AVAILBLE = True
except:
    XFORMERS_IS_AVAILBLE = False


def half(x):
    """Cast to the attention compute precision unless already fp16/bf16."""
    if x.dtype not in [torch.float16, torch.bfloat16]:
        x = x.to(ATTN_PRECISION)
    return x


def attn_processor(q, k, v, attn_mask = None, *args, **kwargs):
    """Dispatch (b, n, heads, c) attention to the best available backend:
    flash-attn 3 > flash-attn 2 > xformers > torch SDPA. Biased (masked)
    attention only goes through xformers or SDPA."""
    if attn_mask is not None:
        if XFORMERS_IS_AVAILBLE:
            return xformers.ops.memory_efficient_attention(
                q, k, v, attn_bias=attn_mask, *args, **kwargs
            )
        # SDPA expects (b, heads, n, c)
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))
        return F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, *args, **kwargs
        ).transpose(1, 2)

    if FLASH_ATTN_3_AVAILABLE:
        out_dtype = v.dtype
        q, k, v = (half(t) for t in (q, k, v))
        return flash_attn_interface.flash_attn_func(q, k, v, *args, **kwargs)[0].to(out_dtype)
    if FLASH_ATTN_AVAILABLE:
        out_dtype = v.dtype
        q, k, v = (half(t) for t in (q, k, v))
        return flash_attn.flash_attn_func(q, k, v, *args, **kwargs).to(out_dtype)
    if XFORMERS_IS_AVAILBLE:
        return xformers.ops.memory_efficient_attention(q, k, v, *args, **kwargs)
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    return F.scaled_dot_product_attention(q, k, v, *args, **kwargs).transpose(1, 2)


def flash_attn_varlen_func(q, k, v, **kwargs):
    """Variable-length flash attention; prefers the v3 interface."""
    if FLASH_ATTN_3_AVAILABLE:
        return flash_attn_interface.flash_attn_varlen_func(q, k, v, **kwargs)[0]
    return flash_attn.flash_attn_varlen_func(q, k, v, **kwargs)


def split_tensor_by_mask(tensor: torch.Tensor, mask: torch.Tensor, threshold: float = 0.5):
    """
    Split input tensor into foreground and background based on mask, then concatenate them.

    Args:
        tensor: Input tensor of shape (batch, seq_len, dim)
        mask: Binary mask of shape (batch, seq_len, 1) or (batch, seq_len)
        threshold: Threshold for mask binarization

    Returns:
        split_tensor: Concatenated tensor with foreground first, then background
        fg_indices: Indices of foreground elements for restoration
        bg_indices: Indices of background elements for restoration
        original_shape: Original tensor shape for restoration
    """
    batch_size, seq_len, *dims = tensor.shape
    device, dtype = tensor.device, tensor.dtype

    # Normalize the mask to (batch, seq_len) booleans
    if mask.dim() == 2:
        mask = mask.unsqueeze(-1)
    keep = (mask > threshold).squeeze(-1)

    # Per-batch index lists (needed later to undo the permutation)
    fg_indices = [torch.where(keep[b])[0] for b in range(batch_size)]
    bg_indices = [torch.where(~keep[b])[0] for b in range(batch_size)]

    max_fg_len = int(keep.sum(dim=1).max().item())
    max_bg_len = int((~keep).sum(dim=1).max().item())

    # Degenerate case: empty sequence
    if max_fg_len == 0 and max_bg_len == 0:
        empty = torch.zeros(batch_size, 0, *dims, device=device, dtype=dtype)
        return empty, fg_indices, bg_indices, tensor.shape

    out = torch.zeros(batch_size, max_fg_len + max_bg_len, *dims, device=device, dtype=dtype)
    for b in range(batch_size):
        n_fg = len(fg_indices[b])
        n_bg = len(bg_indices[b])
        if n_fg:
            out[b, :n_fg] = tensor[b][fg_indices[b]]
        if n_bg:
            out[b, max_fg_len:max_fg_len + n_bg] = tensor[b][bg_indices[b]]

    return out, fg_indices, bg_indices, tensor.shape


def restore_tensor_from_split(split_tensor: torch.Tensor, fg_indices: list, bg_indices: list,
                              original_shape: torch.Size):
    """
    Restore original tensor from split tensor using stored indices.

    Args:
        split_tensor: Split tensor from split_tensor_by_mask
        fg_indices: List of foreground indices for each batch
        bg_indices: List of background indices for each batch
        original_shape: Original tensor shape

    Returns:
        restored_tensor: Restored tensor with original shape and ordering
    """
    batch_size, seq_len = original_shape[:2]
    dims = original_shape[2:]

    max_fg_len = max((len(fg) for fg in fg_indices), default=0)
    restored = torch.zeros(batch_size, seq_len, *dims,
                           device=split_tensor.device, dtype=split_tensor.dtype)

    if split_tensor.shape[1] == 0:
        return restored

    fg_part = split_tensor[:, :max_fg_len] if max_fg_len > 0 else None
    bg_part = split_tensor[:, max_fg_len:] if split_tensor.shape[1] > max_fg_len else None

    for b in range(batch_size):
        if fg_part is not None and len(fg_indices[b]) > 0:
            restored[b, fg_indices[b]] = fg_part[b, :len(fg_indices[b])]
        if bg_part is not None and len(bg_indices[b]) > 0:
            restored[b, bg_indices[b]] = bg_part[b, :len(bg_indices[b])]

    return restored
import os
import math
import numpy as np

from tqdm import tqdm
from einops import rearrange
from refnet.util import exists, append_dims
from refnet.sampling import tps_warp
from refnet.ldm.openaimodel import Timestep, zero_module

import timm
import torch
import torch.nn as nn
import torchvision.transforms
import torch.nn.functional as F

from huggingface_hub import hf_hub_download
from torch.utils.checkpoint import checkpoint
from safetensors.torch import load_file
from transformers import (
    T5EncoderModel,
    T5Tokenizer,
    CLIPVisionModelWithProjection,
    CLIPTextModel,
    CLIPTokenizer,
)

versions = {
    "ViT-bigG-14": "laion2b_s39b_b160k",
    "ViT-H-14": "laion2b_s32b_b79k",  # resblocks layers: 32
    "ViT-L-14": "laion2b_s32b_b82k",
    "hf-hub:apple/DFN5B-CLIP-ViT-H-14-384": None,  # arch name [DFN-ViT-H]
}
hf_versions = {
    "ViT-bigG-14": "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
    "ViT-H-14": "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
    "ViT-L-14": "openai/clip-vit-large-patch14",
}
cache_dir = os.environ.get("HF_HOME", "./pretrained_models")


class WDv14SwinTransformerV2(nn.Module):
    """
    WD-v14-tagger
    Author: Smiling Wolf
    Link: https://huggingface.co/SmilingWolf/wd-v1-4-swinv2-tagger-v2
    """
    # logit used to blank out filtered classes before sigmoid
    negative_logit = -22

    def __init__(
        self,
        input_size = 448,
        antialias = True,
        layer_idx = 0.,
        load_tag = False,
        logit_threshold = None,
        direct_forward = False,
    ):
        """
        Args:
            input_size: Input image size
            antialias: Antialias during rescaling
            layer_idx: Extracted feature layer (counted back from the last block)
            load_tag: Set it to true if use the embedder for image classification
            logit_threshold: Filtering specific channels in logits output
            direct_forward: Bypass the per-block loop and use timm's forward_features
        """
        from refnet.modules import wd_v14_swin2_tagger_config
        super().__init__()
        custom_config = wd_v14_swin2_tagger_config()
        self.model: nn.Module = timm.create_model(
            custom_config.architecture,
            pretrained = False,
            num_classes = custom_config.num_classes,
            global_pool = custom_config.global_pool,
            **custom_config.model_args
        )
        self.image_size = input_size
        self.antialias = antialias
        self.layer_idx = layer_idx
        self.load_tag = load_tag
        self.logit_threshold = logit_threshold
        self.direct_forward = direct_forward

        self.load_from_pretrained_url(load_tag)
        self.get_transformer_length()
        self.model.eval()
        self.model.requires_grad_(False)

        if self.direct_forward:
            # rebind forward to forward_features so self.model(x) skips the head
            self.model.forward = self.model.forward_features.__get__(self.model, self.model.__class__)

    def load_from_pretrained_url(self, load_tag=False):
        """Download (once) and load the tagger weights; optionally load the
        tag-name CSV used by convert_labels."""
        import pandas as pd
        from torch.hub import download_url_to_file
        from data.tag_utils import load_labels, color_tag_index, geometry_tag_index

        ckpt_path = os.path.join(cache_dir, "wd-v14-swin2-tagger.safetensors")
        if not os.path.exists(ckpt_path):
            # download to a temp name and rename, so a partial download never
            # masquerades as a complete checkpoint
            cache_path = os.path.join(cache_dir, "weights.tmp")
            download_url_to_file(
                "https://huggingface.co/SmilingWolf/wd-v1-4-swinv2-tagger-v2/resolve/main/model.safetensors",
                dst = cache_path
            )
            os.rename(cache_path, ckpt_path)

        if load_tag:
            csv_path = hf_hub_download(
                "SmilingWolf/wd-v1-4-swinv2-tagger-v2",
                "selected_tags.csv",
                cache_dir = cache_dir
            )
            tags_df = pd.read_csv(csv_path)
            sep_tags = load_labels(tags_df)

            self.tag_names = sep_tags[0]
            self.rating_indexes = sep_tags[1]
            self.general_indexes = sep_tags[2]
            self.character_indexes = sep_tags[3]

        # tag-channel index tables and weights are needed regardless of load_tag
        self.color_tags = color_tag_index
        self.expr_tags = geometry_tag_index
        self.model.load_state_dict(load_file(ckpt_path))

    def convert_labels(self, pred, general_thresh=0.25, character_thresh=0.85):
        """Convert sigmoid predictions into a prompt string plus the sorted
        (tag, confidence) list of general tags.

        :param pred: array whose first row holds per-class confidences.
        :return: (prompt_string, sorted_general_res)
        """
        assert self.load_tag
        labels = list(zip(self.tag_names, pred[0].astype(float)))

        # General tags: keep anything above the confidence threshold
        general_names = [labels[i] for i in self.general_indexes]
        general_res = dict(
            (name, np.round(conf, decimals=4)) for name, conf in general_names if conf > general_thresh
        )

        # Everything else is characters: pick any where confidence > threshold
        character_names = [labels[i] for i in self.character_indexes]
        character_res = dict(x for x in character_names if x[1] > character_thresh)

        sorted_general_res = sorted(
            general_res.items(),
            key=lambda x: x[1],
            reverse=True,
        )
        general_str = ", ".join(name for name, _ in sorted_general_res)
        general_str = general_str.replace("(", "\\(").replace(")", "\\)")

        # BUG FIX: general and character tags were concatenated without a
        # separator ("...last_tagFirstChar"); join all parts uniformly.
        # (A duplicate, dead computation of the sorted list was also removed.)
        parts = [general_str] if general_str else []
        parts.extend(character_res)
        return ", ".join(parts), sorted_general_res

    def get_transformer_length(self):
        """Count the total number of transformer blocks across all stages."""
        length = 0
        for stage in self.model.layers:
            length += len(stage.blocks)
        self.transformer_length = length

    def transformer_forward(self, x):
        """Run the SwinV2 trunk, optionally stopping layer_idx blocks early."""
        idx = 0
        x = self.model.patch_embed(x)
        for stage in self.model.layers:
            x = stage.downsample(x)
            for blk in stage.blocks:
                if idx == self.transformer_length - self.layer_idx:
                    return x
                if not torch.jit.is_scripting():
                    x = checkpoint(blk, x, use_reentrant=False)
                else:
                    x = blk(x)
                idx += 1
        return x

    def forward(self, x, return_logits=False, pooled=True, **kwargs):
        """
        :param x: preprocessed image batch.
        :param return_logits: also return classification logits.
        :param pooled: pool features (and logits) over the spatial grid.
        :return: [features, logits] where logits is None unless return_logits.
        """
        if self.direct_forward:
            x = self.model(x)
        else:
            x = self.transformer_forward(x)
            x = self.model.norm(x)

        # x: [b, 14, 14, 1024]
        if return_logits:
            if pooled:
                logits = self.model.forward_head(x).unsqueeze(1)
                # logits: [b, 1, num_classes]
            else:
                logits = self.model.head.fc(x)
                logits = rearrange(logits, "b h w c -> b (h w) c").contiguous()
                # logits: [b, 196, num_classes]

            # Cut off unnecessary classes below the threshold.
            # NOTE(review): source indentation was ambiguous; applied to both
            # pooled and unpooled logits — confirm intent.
            if exists(self.logit_threshold) and isinstance(self.logit_threshold, float):
                logits = torch.where(
                    logits > self.logit_threshold,
                    logits,
                    torch.ones_like(logits) * self.negative_logit
                )
        else:
            logits = None

        if pooled:
            x = x.mean(dim=[1, 2]).unsqueeze(1)
        else:
            x = rearrange(x, "b h w c -> b (h w) c").contiguous()
        return [x, logits]

    def preprocess(self, x: torch.Tensor):
        """Resize to the tagger resolution and convert RGB to BGR."""
        x = F.interpolate(
            x,
            (self.image_size, self.image_size),
            mode = "bicubic",
            align_corners = True,
            antialias = self.antialias
        )
        x = x[:, [2, 1, 0]]
        return x

    @torch.no_grad()
    def encode(self, img: torch.Tensor, return_logits=False, pooled=True, **kwargs):
        # Input image must be in RGB format
        return self(self.preprocess(img), return_logits, pooled)

    @torch.no_grad()
    def predict_labels(self, img: torch.Tensor, *args, **kwargs):
        """Predict tags for a single image and convert them to strings."""
        assert len(img.shape) == 4 and img.shape[0] == 1
        logits = self(self.preprocess(img), return_logits=True, pooled=True)[1]
        logits = F.sigmoid(logits).detach().cpu().numpy()
        # NOTE(review): logits is (1, 1, C) here while convert_labels reads
        # pred[0] as a flat confidence vector — verify the expected rank.
        return self.convert_labels(logits, *args, **kwargs)

    def geometry_update(self, emb, geometry_emb, scale_factor=1):
        """Overwrite only the geometry channels of a reference embedding with
        the (scaled) geometry channels taken from the sketch embedding.

        Args:
            emb: WD embedding from reference image
            geometry_emb: WD embedding from sketch image
        """
        geometry_mask = torch.zeros_like(emb)
        geometry_mask[:, :, self.expr_tags] = 1  # Only geometry channels
        return emb * (1 - geometry_mask) + geometry_emb * geometry_mask * scale_factor

    @property
    def dtype(self):
        return self.model.head.fc.weight.dtype
vision_config = kwargs + + if exists(text_config): + text_config.update(kwargs) + else: + text_config = kwargs + + self.visual = FrozenOpenCLIPImageEmbedder(**vision_config) + self.transformer = FrozenOpenCLIPEmbedder(**text_config) + + def preprocess(self, x): + return self.visual.preprocess(x) + + @property + def scale_factor(self): + return self.visual.scale_factor + + def update_scale_factor(self, scale_factor): + self.visual.update_scale_factor(scale_factor) + + def encode(self, *args, **kwargs): + return self.visual.encode(*args, **kwargs) + + @torch.no_grad() + def encode_text(self, text, normalize=True): + return self.transformer(text, normalize) + + def calculate_scale(self, v: torch.Tensor, t: torch.Tensor): + """ + Calculate the projection of v along the direction of t + params: + v: visual tokens from clip image encoder, shape: (b, n, c) + t: text features from clip text encoder (argmax -1), shape: (b, 1, c) + """ + return v @ t.mT + + + +class HFCLIPVisionModel(nn.Module): + # TODO: open_clip_torch is incompatible with deepspeed ZeRO3, change to huggingface implementation in the future + def __init__(self, arch="ViT-bigG-14", image_size=224, scale_factor=1.): + super().__init__() + self.model = CLIPVisionModelWithProjection.from_pretrained( + hf_versions[arch], + cache_dir = cache_dir + ) + self.image_size = image_size + self.scale_factor = scale_factor + self.register_buffer( + 'mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]).view(1, -1, 1, 1), persistent=False + ) + self.register_buffer( + 'std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]).view(1, -1, 1, 1), persistent=False + ) + self.antialias = True + self.requires_grad_(False).eval() + + def preprocess(self, x): + # normalize to [0,1] + ns = int(self.image_size * self.scale_factor) + x = F.interpolate(x, (ns, ns), mode="bicubic", align_corners=True, antialias=self.antialias) + x = (x + 1.0) / 2.0 + + # renormalize according to clip + x = (x - self.mean) / self.std + return x + + 
def forward(self, x, output_type): + outputs = self.model(x).last_hidden_state + if output_type == "cls": + outputs = outputs[:, :1] + elif output_type == "local": + outputs = outputs[:, 1:] + outputs = self.model.vision_model.post_layernorm(outputs) + outputs = self.model.visual_projection(outputs) + return outputs + + @torch.no_grad() + def encode(self, img, output_type="full", preprocess=True, warp_p=0., **kwargs): + img = self.preprocess(img) if preprocess else img + + if warp_p > 0.: + rand = append_dims(torch.rand(img.shape[0], device=img.device, dtype=img.dtype), img.ndim) + img = torch.where(torch.Tensor(rand > warp_p), img, tps_warp(img)) + return self(img, output_type) + + + + +class FrozenT5Embedder(nn.Module): + """Uses the T5 transformer encoder for text""" + + def __init__( + self, version="google/t5-v1_1-xxl", device="cuda", max_length=77, freeze=True + ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl + super().__init__() + self.tokenizer = T5Tokenizer.from_pretrained(version, cache_dir=cache_dir) + self.transformer = T5EncoderModel.from_pretrained(version, cache_dir=cache_dir) + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + tokens = batch_encoding["input_ids"].to(self.device) + with torch.autocast("cuda", enabled=False): + outputs = self.transformer(input_ids=tokens) + z = outputs.last_hidden_state + return z + + @torch.no_grad() + def encode(self, text): + return self(text) + + +class HFCLIPTextEmbedder(nn.Module): + def __init__(self, arch, freeze=True, device="cuda", max_length=77): + super().__init__() + self.tokenizer = 
CLIPTokenizer.from_pretrained( + hf_versions[arch], + cache_dir = cache_dir + ) + self.model = CLIPTextModel.from_pretrained( + hf_versions[arch], + cache_dir = cache_dir + ) + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + + def freeze(self): + self.model = self.model.eval() + + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + if isinstance(text, torch.Tensor) and text.dtype == torch.long: + # Input is already tokenized + tokens = text + else: + # Need to tokenize text input + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + padding="max_length", + return_tensors="pt", + ) + tokens = batch_encoding["input_ids"].to(self.device) + + outputs = self.model(input_ids=tokens) + z = outputs.last_hidden_state + return z + + @torch.no_grad() + def encode(self, text, normalize=False): + outputs = self(text) + if normalize: + outputs = outputs / outputs.norm(dim=-1, keepdim=True) + return outputs + + +class ScalarEmbedder(nn.Module): + """embeds each dimension independently and concatenates them""" + + def __init__(self, embed_dim, out_dim): + super().__init__() + self.timestep = Timestep(embed_dim) + self.embed_layer = nn.Sequential( + nn.Linear(embed_dim, out_dim), + nn.SiLU(), + zero_module(nn.Linear(out_dim, out_features=out_dim)) + ) + + def forward(self, x, dtype=torch.float32): + emb = self.timestep(x) + emb = rearrange(emb, "b d -> b 1 d") + emb = self.embed_layer(emb.to(dtype)) + return emb + + +class TimestepEmbedding(nn.Module): + def __init__(self, embed_dim): + super().__init__() + self.timestep = Timestep(embed_dim) + + def forward(self, x): + x = self.timestep(x) + return x + + +if __name__ == '__main__': + import PIL.Image as Image + + encoder = FrozenOpenCLIPImageEmbedder(arch="DFN-ViT-H") + image = Image.open("../../miniset/origin/70717450.jpg").convert("RGB") + image = (torchvision.transforms.ToTensor()(image) - 0.5) * 2 + image = 
image.unsqueeze(0) + print(image.shape) + feat = encoder.encode(image, "local") + print(feat.shape) \ No newline at end of file diff --git a/refnet/modules/encoder.py b/refnet/modules/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..e111c95b8a64034be2c919ea3c6b8fb6cb40ed86 --- /dev/null +++ b/refnet/modules/encoder.py @@ -0,0 +1,224 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from refnet.util import checkpoint_wrapper +from refnet.modules.unet import TimestepEmbedSequential +from refnet.modules.layers import Upsample, zero_module, RMSNorm, FeedForward +from refnet.modules.attention import MemoryEfficientAttention, MultiScaleCausalAttention +from einops import rearrange +from functools import partial + + + +def make_zero_conv(in_channels, out_channels=None): + out_channels = out_channels or in_channels + return zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0)) + +def activate_zero_conv(in_channels, out_channels=None): + out_channels = out_channels or in_channels + return TimestepEmbedSequential( + nn.SiLU(), + zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0)) + ) + +def sequential_downsample(in_channels, out_channels, sequential_cls=nn.Sequential): + return sequential_cls( + nn.Conv2d(in_channels, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 32, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(32, 32, 3, padding=1), + nn.SiLU(), + nn.Conv2d(32, 96, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(96, 96, 3, padding=1), + nn.SiLU(), + nn.Conv2d(96, 256, 3, padding=1, stride=2), + nn.SiLU(), + zero_module(nn.Conv2d(256, out_channels, 3, padding=1)) + ) + + +class SimpleEncoder(nn.Module): + def __init__(self, c_channels, model_channels): + super().__init__() + self.model = sequential_downsample(c_channels, model_channels) + + def forward(self, x, *args, **kwargs): + return self.model(x) + + +class MultiEncoder(nn.Module): + 
def __init__(self, in_ch, model_channels, ch_mults, checkpoint=True, time_embed=False): + super().__init__() + sequential_cls = TimestepEmbedSequential if time_embed else nn.Sequential + output_chs = [model_channels * mult for mult in ch_mults] + self.model = sequential_downsample(in_ch, model_channels, sequential_cls) + self.zero_layer = make_zero_conv(output_chs[0]) + self.output_blocks = nn.ModuleList() + self.zero_blocks = nn.ModuleList() + + block_num = len(ch_mults) + prev_ch = output_chs[0] + for i in range(block_num): + self.output_blocks.append(sequential_cls( + nn.SiLU(), + nn.Conv2d(prev_ch, output_chs[i], 3, padding=1, stride=2 if i != block_num-1 else 1), + nn.SiLU(), + nn.Conv2d(output_chs[i], output_chs[i], 3, padding=1) + )) + self.zero_blocks.append( + TimestepEmbedSequential(make_zero_conv(output_chs[i])) if time_embed + else make_zero_conv(output_chs[i]) + ) + prev_ch = output_chs[i] + + self.checkpoint = checkpoint + + def forward(self, x): + x = self.model(x) + hints = [self.zero_layer(x)] + for layer, zero_layer in zip(self.output_blocks, self.zero_blocks): + x = layer(x) + hints.append(zero_layer(x)) + return hints + + +class MultiScaleAttentionEncoder(nn.Module): + def __init__( + self, + in_ch, + model_channels, + ch_mults, + dim_head = 128, + transformer_layers = 2, + checkpoint = True + ): + super().__init__() + conv_proj = partial(nn.Conv2d, kernel_size=1, padding=0) + output_chs = [model_channels * mult for mult in ch_mults] + block_num = len(ch_mults) + attn_ch = output_chs[-1] + + self.model = sequential_downsample(in_ch, output_chs[0]) + self.proj_ins = nn.ModuleList([conv_proj(output_chs[0], attn_ch)]) + self.proj_outs = nn.ModuleList([zero_module(conv_proj(attn_ch, output_chs[0]))]) + + prev_ch = output_chs[0] + self.downsample_layers = nn.ModuleList() + for i in range(block_num): + ch = output_chs[i] + self.downsample_layers.append(nn.Sequential( + nn.SiLU(), + nn.Conv2d(prev_ch, ch, 3, padding=1, stride=2 if i != block_num - 1 
else 1), + )) + self.proj_ins.append(conv_proj(ch, attn_ch)) + self.proj_outs.append(zero_module(conv_proj(attn_ch, ch))) + prev_ch = ch + + self.proj_ins.append(conv_proj(attn_ch, attn_ch)) + self.attn_layer = MultiScaleCausalAttention(attn_ch, rope=True, qk_norm=True, dim_head=dim_head) + # self.transformer = nn.ModuleList([ + # BasicTransformerBlock( + # attn_ch, + # rotary_positional_embedding = True, + # qk_norm = True, + # d_head = dim_head, + # disable_cross_attn = True, + # self_attn_type = "multi-scale", + # ff_mult = 2, + # ) + # ] * transformer_layers) + self.checkpoint = checkpoint + + @checkpoint_wrapper + def forward(self, x): + proj_in_iter = iter(self.proj_ins) + proj_out_iter = iter(self.proj_outs[::-1]) + + x = self.model(x) + hints = [rearrange(next(proj_in_iter)(x), "b c h w -> b (h w) c")] + grid_sizes = [(x.shape[2], x.shape[3])] + token_lens = [(x.shape[2] * x.shape[3])] + + for layer in self.downsample_layers: + x = layer(x) + h, w = x.shape[2], x.shape[3] + grid_sizes.append((h, w)) + token_lens.append(h * w) + hints.append(rearrange(next(proj_in_iter)(x), "b c h w -> b (h w) c")) + + hints.append(rearrange( + next(proj_in_iter)(x.mean(dim=[2, 3], keepdim=True)), + "b c h w -> b (h w) c" + )) + + hints = hints[::-1] + grid_sizes = grid_sizes[::-1] + token_lens = token_lens[::-1] + hints = torch.cat(hints, 1) + hints = self.attn_layer(hints, grid_size=grid_sizes, token_lens=token_lens) + hints + # for layer in self.transformer: + # hints = layer(hints, grid_size=grid_sizes, token_lens=token_lens) + + prev_idx = 1 + controls = [] + for gs, token_len in zip(grid_sizes, token_lens): + control = hints[:, prev_idx: prev_idx + token_len] + control = rearrange(control, "b (h w) c -> b c h w", h=gs[0], w=gs[1]) + controls.append(next(proj_out_iter)(control)) + prev_idx = prev_idx + token_len + return controls[::-1] + + + +class Downsampler(nn.Module): + def __init__(self, scale_factor): + super().__init__() + self.scale_factor = scale_factor + + def 
forward(self, x): + return F.interpolate(x, scale_factor=self.scale_factor, mode="bicubic") + + +class SpatialConditionEncoder(nn.Module): + def __init__( + self, + in_dim, + dim, + out_dim, + patch_size, + n_layers = 4, + ): + super().__init__() + self.patch_embed = nn.Conv2d(in_dim, dim, kernel_size=patch_size, stride=patch_size) + self.conv = nn.Sequential(nn.SiLU(), nn.Conv2d(dim, dim, kernel_size=3, padding=1)) + + self.transformer = nn.ModuleList( + nn.ModuleList([ + RMSNorm(dim), + MemoryEfficientAttention(dim, rope=True), + RMSNorm(dim), + FeedForward(dim, mult=2) + ]) for _ in range(n_layers) + ) + self.out = nn.Sequential( + nn.SiLU(), + zero_module(nn.Conv2d(dim, out_dim, kernel_size=1, padding=0)) + ) + + def forward(self, x): + x = self.patch_embed(x) + x = self.conv(x) + + b, c, h, w = x.shape + x = rearrange(x, "b c h w -> b (h w) c") + for norm, layer, norm2, ff in self.transformer: + x = layer(norm(x), grid_size=(h, w)) + x + x = ff(norm2(x)) + x + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) + + return self.out(x) diff --git a/refnet/modules/layers.py b/refnet/modules/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..68facc4a9e95d5f74b680dbf87a370f3b7eaae3a --- /dev/null +++ b/refnet/modules/layers.py @@ -0,0 +1,99 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from refnet.util import default + + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + return output * self.weight + + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() 
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        # Project to 2*dim_out, split into value and gate halves, and gate with GELU.
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)


class FeedForward(nn.Module):
    """
    Transformer feed-forward block: Linear -> GELU (or GEGLU) -> Dropout -> Linear.

    :param dim: input feature dimension.
    :param dim_out: output feature dimension; defaults to ``dim``.
    :param mult: hidden-layer expansion factor (inner dim = dim * mult).
    :param glu: if True, use a GEGLU projection instead of Linear+GELU.
    :param dropout: dropout probability applied after the input projection.
    """
    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.GELU()
        ) if not glu else GEGLU(dim, inner_dim)

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        return self.net(x)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    # detach() shares storage with the parameter, so zero_() clears it in place
    # without recording the operation in autograd.
    for p in module.parameters():
        p.detach().zero_()
    return module


def Normalize(in_channels):
    # 32-group GroupNorm, the convention used throughout diffusion UNet code.
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


class Upsample(nn.Module):
    """
    A 2x nearest-neighbor upsampling layer with an optional 3x3 convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param out_channels: output channels of the optional convolution;
        defaults to ``channels``. (Only used when ``use_conv`` is True.)
+ """ + + def __init__(self, channels, use_conv, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + if use_conv: + self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x \ No newline at end of file diff --git a/refnet/modules/lora.py b/refnet/modules/lora.py new file mode 100644 index 0000000000000000000000000000000000000000..076e16f407bfdff559ee545073a9eaeb529ef76e --- /dev/null +++ b/refnet/modules/lora.py @@ -0,0 +1,370 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Union, Dict, List +from einops import rearrange +from refnet.util import exists, default +from refnet.modules.transformer import BasicTransformerBlock, SelfInjectedTransformerBlock + + +def get_module_safe(self, module_path: str): + current_module = self + try: + for part in module_path.split('.'): + current_module = getattr(current_module, part) + return current_module + except AttributeError: + raise AttributeError(f"Cannot find modules {module_path}") + + +def switch_lora(self, v, label=None): + for t in [self.to_q, self.to_k, self.to_v]: + t.set_lora_active(v, label) + + +def lora_forward(self, x, context, mask, scale=1., scale_factor= None): + def qkv_forward(x, context): + q = self.to_q(x) + k = self.to_k(context) + v = self.to_v(context) + return q, k, v + + assert exists(scale_factor), "Scale factor must be assigned before masked attention" + + mask = rearrange( + F.interpolate(mask, scale_factor=scale_factor, mode="bicubic"), + "b c h w -> b (h w) c" + ).contiguous() + + c1, c2 = context.chunk(2, dim=1) + + # Background region cross-attention + if self.use_lora: + self.switch_lora(False, "foreground") + q2, k2, v2 = qkv_forward(x, c2) + bg_out = 
self.attn_forward(q2, k2, v2, scale) * self.bg_scale + + # Character region cross-attention + if self.use_lora: + self.switch_lora(True, "foreground") + q1, k1, v1 = qkv_forward(x, c1) + fg_out = self.attn_forward(q1, k1, v1, scale) * self.fg_scale + + fg_out = fg_out * (1 - self.merge_scale) + bg_out * self.merge_scale + return fg_out * mask + bg_out * (1 - mask) + # return torch.where(mask > self.mask_threshold, fg_out, bg_out) + + +def dual_lora_forward(self, x, context, mask, scale=1., scale_factor=None): + """ + This function hacks cross-attention layers. + Args: + x: Query input + context: Key and value input + mask: Character mask + scale: Attention scale + sacle_factor: Current latent size factor + + """ + def qkv_forward(x, context): + q = self.to_q(x) + k = self.to_k(context) + v = self.to_v(context) + return q, k, v + + assert exists(scale_factor), "Scale factor must be assigned before masked attention" + + mask = rearrange( + F.interpolate(mask, scale_factor=scale_factor, mode="bicubic"), + "b c h w -> b (h w) c" + ).contiguous() + + c1, c2 = context.chunk(2, dim=1) + + # Background region cross-attention + if self.use_lora: + self.switch_lora(True, "background") + self.switch_lora(False, "foreground") + q2, k2, v2 = qkv_forward(x, c2) + bg_out = self.attn_forward(q2, k2, v2, scale) * self.bg_scale + + # Foreground region cross-attention + if self.use_lora: + self.switch_lora(False, "background") + self.switch_lora(True, "foreground") + q1, k1, v1 = qkv_forward(x, c1) + fg_out = self.attn_forward(q1, k1, v1, scale) * self.fg_scale + + fg_out = fg_out * (1 - self.merge_scale) + bg_out * self.merge_scale + # return fg_out * mask + bg_out * (1 - mask) + return torch.where(mask > self.mask_threshold, fg_out, bg_out) + + + +class MultiLoraInjectedLinear(nn.Linear): + """ + A linear layer that can hold multiple LoRA adapters and merge them. 
+ """ + def __init__( + self, + in_features, + out_features, + bias = False, + ): + super().__init__(in_features, out_features, bias) + self.lora_adapters: Dict[str, Dict[str, nn.Module]] = {} # {label: {up/down: layer}} + self.lora_scales: Dict[str, float] = {} + self.active_loras: Dict[str, bool] = {} + self.original_weight = None + self.original_bias = None + + # Freeze original weights + self.weight.requires_grad_(False) + if exists(self.bias): + self.bias.requires_grad_(False) + + def add_lora_adapter(self, label: str, r: int, scale: float = 1.0, dropout_p: float = 0.0): + """Add a new LoRA adapter with the given label.""" + if isinstance(r, float): + r = int(r * self.out_features) + + lora_down = nn.Linear(self.in_features, r, bias=self.bias is not None) + lora_up = nn.Linear(r, self.out_features, bias=self.bias is not None) + dropout = nn.Dropout(dropout_p) + + # Initialize weights + nn.init.normal_(lora_down.weight, std=1 / r) + nn.init.zeros_(lora_up.weight) + + self.lora_adapters[label] = { + 'down': lora_down, + 'up': lora_up, + 'dropout': dropout, + } + self.lora_scales[label] = scale + self.active_loras[label] = True + + # Register as submodules + self.add_module(f'lora_down_{label}', lora_down) + self.add_module(f'lora_up_{label}', lora_up) + self.add_module(f'lora_dropout_{label}', dropout) + + def get_trainable_layers(self, label: str = None): + """Get trainable layers for specific LoRA or all LoRAs.""" + layers = [] + if exists(label): + if label in self.lora_adapters: + adapter = self.lora_adapters[label] + layers.extend([adapter['down'], adapter['up']]) + else: + for adapter in self.lora_adapters.values(): + layers.extend([adapter['down'], adapter['up']]) + return layers + + def set_lora_active(self, active: bool, label: str): + """Activate or deactivate a specific LoRA adapter.""" + if label in self.active_loras: + self.active_loras[label] = active + + def set_lora_scale(self, scale: float, label: str): + """Set the scale for a specific LoRA 
adapter.""" + if label in self.lora_scales: + self.lora_scales[label] = scale + + def merge_lora_weights(self, labels: List[str] = None): + """Merge specified LoRA adapters into the base weights.""" + if labels is None: + labels = list(self.lora_adapters.keys()) + + # Store original weights if not already stored + if self.original_weight is None: + self.original_weight = self.weight.clone() + if exists(self.bias): + self.original_bias = self.bias.clone() + + merged_weight = self.original_weight.clone() + merged_bias = self.original_bias.clone() if exists(self.original_bias) else None + + for label in labels: + if label in self.lora_adapters and self.active_loras.get(label, False): + lora_up, lora_down = self.lora_adapters[label]['up'], self.lora_adapters[label]['down'] + scale = self.lora_scales[label] + + lora_weight = lora_up.weight @ lora_down.weight + merged_weight += scale * lora_weight + + if exists(merged_bias) and exists(lora_up.bias): + lora_bias = lora_up.bias + lora_up.weight @ lora_down.bias + merged_bias += scale * lora_bias + + # Update weights + self.weight = nn.Parameter(merged_weight, requires_grad=False) + if exists(merged_bias): + self.bias = nn.Parameter(merged_bias, requires_grad=False) + + # Deactivate all LoRAs after merging + for label in labels: + self.active_loras[label] = False + + def recover_original_weight(self): + """Recover the original weights before any LoRA modifications.""" + if self.original_weight is not None: + self.weight = nn.Parameter(self.original_weight.clone()) + if exists(self.original_bias): + self.bias = nn.Parameter(self.original_bias.clone()) + + # Reactivate all LoRAs + for label in self.active_loras: + self.active_loras[label] = True + + def forward(self, input): + output = super().forward(input) + + # Add contributions from active LoRAs + for label, adapter in self.lora_adapters.items(): + if self.active_loras.get(label, False): + lora_out = adapter['up'](adapter['dropout'](adapter['down'](input))) + output += 
self.lora_scales[label] * lora_out + + return output + + +class LoraModules: + def __init__(self, sd, lora_params, *args, **kwargs): + self.modules = {} + self.multi_lora_layers: Dict[str, MultiLoraInjectedLinear] = {} # path -> MultiLoraLayer + + for cfg in lora_params: + root_module = get_module_safe(sd, cfg.pop("root_module")) + label = cfg.pop("label", "lora") + self.inject_lora(label, root_module, **cfg) + + def inject_lora( + self, + label, + root_module, + r, + split_forward = False, + target_keys = ("to_q", "to_k", "to_v"), + filter_keys = None, + target_class = None, + scale = 1.0, + dropout_p = 0.0, + ): + def check_condition(path, child, class_list): + if exists(filter_keys) and any(path.find(key) > -1 for key in filter_keys): + return False + if exists(target_keys) and any(path.endswith(key) for key in target_keys): + return True + if exists(class_list) and any( + isinstance(child, module_class) for module_class in class_list + ): + return True + return False + + def retrieve_target_modules(): + from refnet.util import get_obj_from_str + target_class_list = [get_obj_from_str(t) for t in target_class] if exists(target_class) else None + + modules = [] + for name, module in root_module.named_modules(): + for key, child in module._modules.items(): + full_path = name + '.' 
+ key if name else key + if check_condition(full_path, child, target_class_list): + modules.append((module, child, key, full_path)) + return modules + + modules: list[Union[nn.Module]] = [] + retrieved_modules = retrieve_target_modules() + + for parent, child, child_name, full_path in retrieved_modules: + # Check if this layer already has a MultiLoraInjectedLinear + if full_path in self.multi_lora_layers: + # Add LoRA to existing MultiLoraInjectedLinear + multi_lora_layer = self.multi_lora_layers[full_path] + multi_lora_layer.add_lora_adapter(label, r, scale, dropout_p) + else: + # Check if the current layer is already a MultiLoraInjectedLinear + if isinstance(child, MultiLoraInjectedLinear): + child.add_lora_adapter(label, r, scale, dropout_p) + self.multi_lora_layers[full_path] = child + else: + # Replace with MultiLoraInjectedLinear and add first LoRA + multi_lora_layer = MultiLoraInjectedLinear( + in_features=child.weight.shape[1], + out_features=child.weight.shape[0], + bias=exists(child.bias), + ) + + multi_lora_layer.add_lora_adapter(label, r, scale, dropout_p) + parent._modules[child_name] = multi_lora_layer + self.multi_lora_layers[full_path] = multi_lora_layer + + if split_forward: + parent.masked_forward = dual_lora_forward.__get__(parent, parent.__class__) + else: + parent.masked_forward = lora_forward.__get__(parent, parent.__class__) + + parent.use_lora = True + parent.switch_lora = switch_lora.__get__(parent, parent.__class__) + modules.append(parent) + + self.modules[label] = modules + print(f"Activated {label} lora with {len(self.multi_lora_layers)} layers") + return self.multi_lora_layers, modules + + def get_trainable_layers(self, label = None): + """Get all trainable layers, optionally filtered by label.""" + layers = [] + for lora_layer in self.multi_lora_layers.values(): + layers += lora_layer.get_trainable_layers(label) + return layers + + def switch_lora(self, mode, label = None): + if exists(label): + for layer in 
self.multi_lora_layers.values(): + layer.set_lora_active(mode, label) + for module in self.modules[label]: + module.use_lora = mode + else: + for layer in self.multi_lora_layers.values(): + for lora_label in layer.lora_adapters.keys(): + layer.set_lora_active(mode, lora_label) + + for modules in self.modules.values(): + for module in modules: + module.use_lora = mode + + def adjust_lora_scales(self, scale, label = None): + if exists(label): + for layer in self.multi_lora_layers.values(): + layer.set_lora_scale(scale, label) + else: + for layer in self.multi_lora_layers.values(): + for lora_label in layer.lora_adapters.keys(): + layer.set_lora_scale(scale, lora_label) + + def merge_lora(self, labels = None): + if labels is None: + labels = list(self.modules.keys()) + elif isinstance(labels, str): + labels = [labels] + + for layer in self.multi_lora_layers.values(): + layer.merge_lora_weights(labels) + + def recover_lora(self): + for layer in self.multi_lora_layers.values(): + layer.recover_original_weight() + + def get_lora_info(self): + """Get information about all LoRA adapters.""" + info = {} + for path, layer in self.multi_lora_layers.items(): + info[path] = { + 'labels': list(layer.lora_adapters.keys()), + 'active': {label: active for label, active in layer.active_loras.items()}, + 'scales': layer.lora_scales.copy() + } + return info \ No newline at end of file diff --git a/refnet/modules/proj.py b/refnet/modules/proj.py new file mode 100644 index 0000000000000000000000000000000000000000..b619a088f21ef1774a16ba33179fd968bd2e101c --- /dev/null +++ b/refnet/modules/proj.py @@ -0,0 +1,142 @@ +import torch +import torch.nn as nn + +from refnet.modules.layers import zero_module +from refnet.modules.attention import MemoryEfficientAttention +from refnet.modules.transformer import BasicTransformerBlock +from refnet.util import checkpoint_wrapper, exists +from refnet.util import load_weights + + +class NormalizedLinear(nn.Module): + def __init__(self, dim, output_dim, 
checkpoint=True): + super().__init__() + self.layers = nn.Sequential( + nn.Linear(dim, output_dim), + nn.LayerNorm(output_dim) + ) + self.checkpoint = checkpoint + + @checkpoint_wrapper + def forward(self, x): + return self.layers(x) + + +class GlobalProjection(nn.Module): + def __init__(self, input_dim, output_dim, heads, dim_head=128, checkpoint=True): + super().__init__() + self.c_dim = output_dim + self.dim_head = dim_head + self.head = (heads[0], heads[0] * heads[1]) + + self.proj1 = nn.Linear(input_dim, dim_head * heads[0]) + self.proj2 = nn.Sequential( + nn.SiLU(), + zero_module(nn.Linear(dim_head, output_dim * heads[1])), + ) + self.norm = nn.LayerNorm(output_dim) + self.checkpoint = checkpoint + + @checkpoint_wrapper + def forward(self, x): + x = self.proj1(x).reshape(-1, self.head[0], self.dim_head).contiguous() + x = self.proj2(x).reshape(-1, self.head[1], self.c_dim).contiguous() + return self.norm(x) + + +class ClusterConcat(nn.Module): + def __init__(self, input_dim, c_dim, output_dim, dim_head=64, token_length=196, checkpoint=True): + super().__init__() + self.attn = MemoryEfficientAttention(input_dim, dim_head=dim_head) + self.norm = nn.LayerNorm(input_dim) + self.proj = nn.Sequential( + nn.Linear(input_dim + c_dim, output_dim), + nn.SiLU(), + nn.Linear(output_dim, output_dim), + nn.LayerNorm(output_dim) + ) + self.token_length = token_length + self.checkpoint = checkpoint + + @checkpoint_wrapper + def forward(self, x, emb, fgbg=False, *args, **kwargs): + x = self.attn(x)[:, :self.token_length] + x = self.norm(x) + x = torch.cat([x, emb], 2) + x = self.proj(x) + + if fgbg: + x = torch.cat(torch.chunk(x, 2), 1) + return x + + +class RecoveryClusterConcat(ClusterConcat): + def __init__(self, input_dim, c_dim, output_dim, dim_head=64, *args, **kwargs): + super().__init__(input_dim, c_dim, output_dim, dim_head=dim_head, *args, **kwargs) + self.transformer = BasicTransformerBlock( + output_dim, output_dim//dim_head, dim_head, + disable_cross_attn=True, 
class Concat(nn.Module):
    """Trivial fusion module: concatenate two tensors along the last dimension.

    Extra constructor/forward arguments are accepted and ignored so it can be
    swapped in for richer fusion modules sharing the same call signature.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x, y, *args, **kwargs):
        return torch.cat([x, y], dim=-1)
def modulation(x, scale, shift):
    """FiLM-style affine modulation: gate `x` by (1 + scale), then add `shift`.

    The +1 bias keeps the operation an identity when scale and shift are zero
    (e.g. when produced by zero-initialized projection layers).
    """
    return x * (scale + 1) + shift
1) + + else: + h = torch.cat([h, h_skip], 1) + + h = module(h, emb, context, mask, **additional_context) + + if exists(style_modulations): + style_norm, emb_proj, style_proj = self.style_modules[idx] + style_m = style_modulations[idx] + emb_proj(emb) + style_m = style_proj(style_norm(style_m))[...,None,None] + scale, shift = style_m.chunk(2, dim=1) + + h = modulation(h, scale, shift) + + return h + +def enhanced_forward_xl( + self, + x: torch.Tensor, + emb, + z_fg: torch.Tensor = None, + z_bg: torch.Tensor = None, + hs_fg: torch.Tensor = None, + hs_bg: torch.Tensor = None, + mask: torch.Tensor = None, + inject_mask: torch.Tensor = None, + threshold: Union[float|torch.Tensor] = None, + concat: torch.Tensor = None, + control: torch.Tensor = None, + context: torch.Tensor = None, + style_modulations: torch.Tensor = None, + **additional_context +): + h = x.to(self.dtype) + emb = emb.to(self.dtype) + hs = [] + control_iter = iter(control) + + if exists(concat): + h = torch.cat([h, concat], 1) + h = h + self.concat_conv(h) + + for idx, module in enumerate(self.input_blocks): + h = module(h, emb, context, mask, **additional_context) + + if idx in self.hint_encoder_index: + h += next(control_iter) + + if exists(z_fg): + h += self.conv_fg(z_fg) + z_fg = None + if exists(z_bg): + h += self.conv_bg(z_bg) + z_bg = None + + hs.append(h) + + h = self.middle_block(h, emb, context, mask, **additional_context) + + for idx, module in enumerate(self.output_blocks): + h_skip = hs.pop() + + if exists(inject_mask) and exists(threshold): + # inject foreground/background features + B, C, H, W = h_skip.shape + cm = F.interpolate(inject_mask, (H, W), mode="bicubic") + h = torch.cat([h, torch.where( + cm > threshold, + + # foreground injection + rearrange( + self.map_modules[idx][0]( + rearrange(h_skip, "b c h w -> b (h w) c"), + hs_fg[idx] + self.map_modules[idx][1](emb).unsqueeze(1) + ), "b (h w) c -> b c h w", h=H, w=W + ) + h_skip if exists(hs_fg) else h_skip, + + # background injection + 
def InferenceForward(self, x, timesteps=None, y=None, *args, **kwargs):
    """Inference-time forward (monkey-patched in): build the timestep embedding,
    fold in the class/adm label embedding when the model is class-conditional,
    then delegate to the hacked `_forward`.
    """
    sinusoid = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
    emb = self.time_embed(sinusoid).to(self.dtype)

    conditional = self.num_classes is not None
    assert (y is not None) == conditional, "must specify y if and only if the model is class-conditional"
    if conditional:
        assert y.shape[0] == x.shape[0]
        emb = emb + self.label_emb(y.to(self.dtype))

    emb = emb.to(self.dtype)
    return self._forward(x, emb, *args, **kwargs)
"vanilla", + style_modulation = False, + ): + super().__init__() + if use_spatial_transformer: + assert exists( + context_dim) or disable_cross_attentions, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + assert transformer_type in self.transformers.keys(), f'Assigned transformer is not implemented.. Choices: {self.transformers.keys()}' + from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + time_embed_dim = model_channels * 4 + resblock = partial( + ResBlock, + emb_channels=time_embed_dim, + dropout=dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + + transformer = partial( + self.transformers[transformer_type], + context_dim=context_dim, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, + disable_self_attn=disable_self_attentions, + disable_cross_attn=disable_cross_attentions, + ) + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + self.in_channels = in_channels + self.model_channels = model_channels + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError("provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult") + self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: + # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all( + map(lambda i: self.num_res_blocks[i] >= 
num_attention_blocks[i], range(len(num_attention_blocks)))) + print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set.") + + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = torch.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.style_modulation = style_modulation + + if isinstance(transformer_depth, int): + transformer_depth = len(channel_mult) * [transformer_depth] + + time_embed_dim = model_channels * 4 + zero_conv = partial(nn.Conv2d, kernel_size=1, stride=1, padding=0) + + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + if isinstance(self.num_classes, int): + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + elif self.num_classes == "continuous": + print("setting up linear c_adm embedding layer") + self.label_emb = nn.Linear(1, time_embed_dim) + elif self.num_classes == "sequential": + assert adm_in_channels is not None + self.label_emb = nn.Sequential( + nn.Sequential( + linear(adm_in_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + ) + else: + raise ValueError() + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self.zero_layers = nn.ModuleList([zero_module( + nn.Linear(model_channels, model_channels * 2) if style_modulation else + zero_conv(model_channels, model_channels) + )]) + + ch = model_channels + ds = 1 + for level, mult 
def forward(self, x, timesteps=None, y=None, *args, **kwargs):
    """Embed timesteps (plus class labels when class-conditional) and run the
    encoder via `_forward`, returning the per-block hint features.
    """
    emb = self.time_embed(
        timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
    )

    conditional = self.num_classes is not None
    assert (y is not None) == conditional, "must specify y if and only if the model is class-conditional"
    if conditional:
        assert y.shape[0] == x.shape[0]
        emb = emb + self.label_emb(y.to(self.dtype))

    return self._forward(x, emb, *args, **kwargs)
module in enumerate(self.input_blocks): + h = module(h, emb, context, **additional_context) + + if self.style_modulation: + hint = self.zero_layers[idx](h.mean(dim=[2, 3])) + hints.append(hint) + + else: + hint = self.zero_layers[idx](h) + hint = rearrange(hint, "b c h w -> b (h w) c").contiguous() + hints.append(hint) + + hints.reverse() + return hints \ No newline at end of file diff --git a/refnet/modules/transformer.py b/refnet/modules/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..c6406307020e811617b812fd45bafbc4bcf0d079 --- /dev/null +++ b/refnet/modules/transformer.py @@ -0,0 +1,232 @@ +import torch +import torch.nn as nn + +from functools import partial +from einops import rearrange + +from refnet.util import checkpoint_wrapper, exists +from refnet.modules.layers import FeedForward, Normalize, zero_module, RMSNorm +from refnet.modules.attention import MemoryEfficientAttention, MultiModalAttention, MultiScaleCausalAttention + + +class BasicTransformerBlock(nn.Module): + ATTENTION_MODES = { + "vanilla": MemoryEfficientAttention, + "multi-scale": MultiScaleCausalAttention, + "multi-modal": MultiModalAttention, + } + def __init__( + self, + dim, + n_heads = None, + d_head = 64, + dropout = 0., + context_dim = None, + gated_ff = True, + ff_mult = 4, + checkpoint = True, + disable_self_attn = False, + disable_cross_attn = False, + self_attn_type = "vanilla", + cross_attn_type = "vanilla", + rotary_positional_embedding = False, + context_dim_2 = None, + casual_self_attn = False, + casual_cross_attn = False, + qk_norm = False, + norm_type = "layer", + ): + super().__init__() + assert self_attn_type in self.ATTENTION_MODES + assert cross_attn_type in self.ATTENTION_MODES + self_attn_cls = self.ATTENTION_MODES[self_attn_type] + crossattn_cls = self.ATTENTION_MODES[cross_attn_type] + + if norm_type == "layer": + norm_cls = nn.LayerNorm + elif norm_type == "rms": + norm_cls = RMSNorm + else: + raise 
@checkpoint_wrapper
def forward(self, x, context=None, mask=None, emb=None, **kwargs):
    """Pre-norm residual stack: self-attention, optional cross-attention, feed-forward.

    `emb` is accepted for signature compatibility with injected variants but is
    unused here; extra kwargs are forwarded to the self-attention only.
    """
    x = x + self.attn1(self.norm1(x), **kwargs)
    if not self.disable_cross_attn:
        x = x + self.attn2(self.norm2(x), context, mask, self.reference_scale, self.scale_factor)
    x = x + self.ff(self.norm3(x))
    return x
class SelfTransformerBlock(nn.Module):
    """Pre-norm self-attention block with a zero-initialized SiLU MLP.

    With `reshape=True` the block accepts a (b, c, h, w) feature map, flattens
    it to tokens, and restores the spatial layout afterwards.
    NOTE(review): `x.shape` is unpacked as (b, c, h, w) even when
    `reshape=False`, so 4D input appears to be assumed in both modes — confirm
    against callers.
    """

    def __init__(
        self,
        dim,
        dim_head = 64,
        dropout = 0.,
        mlp_ratio = 4,
        checkpoint = True,
        casual_attn = False,
        reshape = True
    ):
        super().__init__()
        self.attn = MemoryEfficientAttention(
            query_dim=dim, heads=dim // dim_head, dropout=dropout, casual=casual_attn
        )
        self.ff = nn.Sequential(
            nn.Linear(dim, dim * mlp_ratio),
            nn.SiLU(),
            zero_module(nn.Linear(dim * mlp_ratio, dim)),
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.reshape = reshape
        self.checkpoint = checkpoint

    @checkpoint_wrapper
    def forward(self, x, context=None):
        b, c, h, w = x.shape
        if self.reshape:
            x = rearrange(x, 'b c h w -> b (h w) c').contiguous()

        x = x + self.attn(self.norm1(x), context if exists(context) else None)
        x = x + self.ff(self.norm2(x))

        if self.reshape:
            x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        return x
def SpatialTransformer(*args, **kwargs):
    """Factory: a Transformer stack built from vanilla BasicTransformerBlocks."""
    return Transformer(*args, type="vanilla", **kwargs)


def SelfInjectTransformer(*args, **kwargs):
    """Factory: a Transformer stack built from self-injection blocks (reference banks)."""
    return Transformer(*args, type="self-injection", **kwargs)
def InferenceForward(self, x, timesteps=None, y=None, *args, **kwargs):
    """Inference-only forward (monkey-patched in): build the timestep/class
    embedding, run the UNet body via `_forward`, and project through the
    output head in the input's dtype.
    """
    sinusoid = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
    emb = self.time_embed(sinusoid).to(self.dtype)

    conditional = self.num_classes is not None
    assert (y is not None) == conditional, "must specify y if and only if the model is class-conditional"
    if conditional:
        assert y.shape[0] == x.shape[0]
        # NOTE(review): y is moved to emb.device here, while the reference-net
        # variant casts y to self.dtype instead — confirm which is intended.
        emb = emb + self.label_emb(y.to(emb.device))

    emb = emb.to(self.dtype)
    features = self._forward(x, emb, *args, **kwargs)
    return self.out(features.to(x.dtype))
+ self._dispatch = tuple( + self._D_TIMESTEP if isinstance(layer, TimestepBlock) else + self._D_TRANSFORMER if isinstance(layer, Transformer) else + self._D_OTHER + for layer in self + ) + + def forward(self, x, emb=None, context=None, mask=None, **additional_context): + for layer, d in zip(self, self._dispatch): + if d == self._D_TIMESTEP: + x = layer(x, emb) + elif d == self._D_TRANSFORMER: + x = layer(x, context, mask, emb, **additional_context) + else: + x = layer(x) + return x + + + +class UNetModel(nn.Module): + transformers = { + "vanilla": SpatialTransformer, + "selfinj": SelfInjectTransformer, + } + def __init__( + self, + in_channels, + model_channels, + num_res_blocks, + attention_resolutions, + out_channels = 4, + dropout = 0, + channel_mult = (1, 2, 4, 8), + conv_resample = True, + dims = 2, + num_classes = None, + use_checkpoint = False, + num_heads = -1, + num_head_channels = -1, + use_scale_shift_norm = False, + resblock_updown = False, + use_spatial_transformer = False, # custom transformer support + transformer_depth = 1, # custom transformer support + context_dim = None, # custom transformer support + disable_self_attentions = None, + disable_cross_attentions = False, + num_attention_blocks = None, + use_linear_in_transformer = False, + adm_in_channels = None, + transformer_type = "vanilla", + map_module = False, + warp_module = False, + style_modulation = False, + discard_final_layers = False, # for reference net + additional_transformer_config = None, + in_channels_fg = None, + in_channels_bg = None, + ): + super().__init__() + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' 
+ from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + assert num_heads > -1 or num_head_channels > -1, 'Either num_heads or num_head_channels has to be set' + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError("provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult") + self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: + # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) + print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. 
" + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set.") + + self.num_classes = num_classes + self.model_channels = model_channels + self.dtype = torch.float32 + + if isinstance(transformer_depth, int): + transformer_depth = len(channel_mult) * [transformer_depth] + transformer_depth_middle = transformer_depth[-1] + time_embed_dim = model_channels * 4 + resblock = partial( + ResBlock, + emb_channels = time_embed_dim, + dropout = dropout, + dims = dims, + use_checkpoint = use_checkpoint, + use_scale_shift_norm = use_scale_shift_norm, + ) + transformer = partial( + self.transformers[transformer_type], + context_dim = context_dim, + use_linear = use_linear_in_transformer, + use_checkpoint = use_checkpoint, + disable_self_attn = disable_self_attentions, + disable_cross_attn = disable_cross_attentions, + transformer_config = additional_transformer_config + ) + + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + if isinstance(self.num_classes, int): + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + elif self.num_classes == "continuous": + print("setting up linear c_adm embedding layer") + self.label_emb = nn.Linear(1, time_embed_dim) + elif self.num_classes == "sequential": + assert adm_in_channels is not None + self.label_emb = nn.Sequential( + nn.Sequential( + linear(adm_in_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + ) + else: + raise ValueError() + + self.input_blocks = nn.ModuleList([ + TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1)) + ]) + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in 
range(self.num_res_blocks[level]): + layers = [resblock(ch, out_channels=mult * model_channels)] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels > -1: + current_num_heads = ch // num_head_channels + current_head_dim = num_head_channels + else: + current_num_heads = num_heads + current_head_dim = ch // num_heads + + if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: + layers.append( + SelfTransformerBlock(ch, current_head_dim) + if not use_spatial_transformer + else transformer( + ch, current_num_heads, current_head_dim, + depth=transformer_depth[level], + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append(TimestepEmbedSequential( + resblock(ch, out_channels=out_ch, down=True) if resblock_updown + else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch) + )) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + + if num_head_channels > -1: + current_num_heads = ch // num_head_channels + current_head_dim = num_head_channels + else: + current_num_heads = num_heads + current_head_dim = ch // num_heads + self.middle_block = TimestepEmbedSequential( + resblock(ch), + SelfTransformerBlock(ch, current_head_dim) if not use_spatial_transformer + else transformer(ch, current_num_heads, current_head_dim, depth=transformer_depth_middle), + resblock(ch), + ) + + self.output_blocks = nn.ModuleList([]) + self.map_modules = nn.ModuleList([]) + self.warp_modules = nn.ModuleList([]) + self.style_modules = nn.ModuleList([]) + + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(self.num_res_blocks[level] + 1): + ich = input_block_chans.pop() + layers = [resblock(ch + ich, out_channels=model_channels * mult)] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels > -1: + current_num_heads = ch // num_head_channels + current_head_dim = 
num_head_channels + else: + current_num_heads = num_heads + current_head_dim = ch // num_heads + + if not exists(num_attention_blocks) or i < num_attention_blocks[level]: + layers.append( + SelfTransformerBlock(ch, current_head_dim) if not use_spatial_transformer + else transformer( + ch, current_num_heads, current_head_dim, depth=transformer_depth[level] + ) + ) + if level and i == self.num_res_blocks[level]: + out_ch = ch + layers.append( + resblock(ch, up=True) if resblock_updown else Upsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + if level == 0 and discard_final_layers: + break + + if map_module: + self.map_modules.append(nn.ModuleList([ + MemoryEfficientAttention( + ich, + heads = ich // num_head_channels, + dim_head = num_head_channels + ), + nn.Linear(time_embed_dim, ich) + ])) + + if warp_module: + self.warp_modules.append(nn.ModuleList([ + MemoryEfficientAttention( + ich, + heads = ich // num_head_channels, + dim_head = num_head_channels + ), + nn.Linear(time_embed_dim, ich) + ])) + + # self.warp_modules.append(nn.ModuleList([ + # SpatialTransformer(ich, ich//num_head_channels, num_head_channels), + # nn.Linear(time_embed_dim, ich) + # ])) + + if style_modulation: + self.style_modules.append(nn.ModuleList([ + nn.LayerNorm(ch*2), + nn.Linear(time_embed_dim, ch*2), + zero_module(nn.Linear(ch*2, ch*2)) + ])) + + if not discard_final_layers: + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + + self.conv_fg = zero_module( + conv_nd(dims, in_channels_fg, model_channels, 3, padding=1) + ) if exists(in_channels_fg) else None + self.conv_bg = zero_module( + conv_nd(dims, in_channels_bg, model_channels, 3, padding=1) + ) if exists(in_channels_bg) else None + + def forward(self, x, timesteps=None, y=None, *args, **kwargs): + t_emb = timestep_embedding(timesteps, self.model_channels, 
class DualCondUNetXL(UNetModel):
    """UNetModel variant with dual conditioning: an optional channel-wise
    concat input plus additive hint features injected at the listed
    encoder/decoder block indices.
    """

    def __init__(
        self,
        hint_encoder_index = (0, 3, 6, 8),
        hint_decoder_index = (),
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        # Input/output block indices that each consume one hint feature from `control`.
        self.hint_encoder_index = hint_encoder_index
        self.hint_decoder_index = hint_decoder_index

    def _forward(self, x, emb, concat=None, control=None, context=None, mask=None, **additional_context):
        h = x.to(self.dtype)
        skips = []

        # Optional extra conditioning map stacked on the channel axis.
        if exists(concat):
            h = torch.cat([h, concat], 1)

        # NOTE(review): `control` is iterated unconditionally, so callers must
        # always supply it (None would raise here) — confirm against call sites.
        hints = iter(control)

        for idx, block in enumerate(self.input_blocks):
            h = block(h, emb, context, mask, **additional_context)
            if idx in self.hint_encoder_index:
                h = h + next(hints)
            skips.append(h)

        h = self.middle_block(h, emb, context, mask, **additional_context)

        for idx, block in enumerate(self.output_blocks):
            h = block(torch.cat([h, skips.pop()], dim=1), emb, context, mask, **additional_context)
            if idx in self.hint_decoder_index:
                h = h + next(hints)

        return h
+ + def forward(self, x, timesteps=None, y=None, *args, **kwargs): + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype) + emb = self.time_embed(t_emb) + + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y.to(self.dtype)) + self._forward(x, emb, *args, **kwargs) + + def _forward(self, *args, **kwargs): + super()._forward(*args, **kwargs) + return None \ No newline at end of file diff --git a/refnet/modules/unet_old.py b/refnet/modules/unet_old.py new file mode 100644 index 0000000000000000000000000000000000000000..1474fb552a9208918a2ad47e6453a620c90aea59 --- /dev/null +++ b/refnet/modules/unet_old.py @@ -0,0 +1,596 @@ +import torch +import torch.nn as nn + +from functools import partial +from refnet.util import exists +from refnet.modules.transformer import ( + SelfTransformerBlock, + Transformer, + SpatialTransformer, + rearrange +) +from refnet.ldm.openaimodel import ( + timestep_embedding, + conv_nd, + TimestepBlock, + zero_module, + ResBlock, + linear, + Downsample, + Upsample, + normalization, +) + +try: + import xformers + import xformers.ops + XFORMERS_IS_AVAILBLE = True +except: + XFORMERS_IS_AVAILBLE = False + + +def hack_inference_forward(model): + model.forward = InferenceForward.__get__(model, model.__class__) + +def InferenceForward(self, x, timesteps=None, y=None, *args, **kwargs): + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb).to(self.dtype) + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y.to(emb.device)) + emb = emb.to(self.dtype) + h = self._forward(x, emb, *args, **kwargs) + return 
self.out(h.to(x.dtype)) + + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None, mask=None, **additional_context): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, Transformer): + x = layer(x, context, mask, **additional_context) + else: + x = layer(x) + return x + + + +class UNetModel(nn.Module): + transformers = { + "vanilla": SpatialTransformer, + } + def __init__( + self, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout = 0, + channel_mult = (1, 2, 4, 8), + conv_resample = True, + dims = 2, + num_classes = None, + use_checkpoint = False, + num_heads = -1, + num_head_channels = -1, + use_scale_shift_norm = False, + resblock_updown = False, + use_spatial_transformer = False, # custom transformer support + transformer_depth = 1, # custom transformer support + context_dim = None, # custom transformer support + disable_self_attentions = None, + disable_cross_attentions = None, + num_attention_blocks = None, + use_linear_in_transformer = False, + adm_in_channels = None, + transformer_type = "vanilla", + map_module = False, + warp_module = False, + style_modulation = False, + ): + super().__init__() + if use_spatial_transformer: + assert exists(context_dim) or disable_cross_attentions, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + assert transformer_type in self.transformers.keys(), f'Assigned transformer is not implemented.. Choices: {self.transformers.keys()}' + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' 
+ from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + assert num_heads > -1 or num_head_channels > -1, 'Either num_heads or num_head_channels has to be set' + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError("provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult") + self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: + # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) + print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. 
" + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set.") + + self.num_classes = num_classes + self.model_channels = model_channels + self.dtype = torch.float32 + + if isinstance(transformer_depth, int): + transformer_depth = len(channel_mult) * [transformer_depth] + transformer_depth_middle = transformer_depth[-1] + time_embed_dim = model_channels * 4 + resblock = partial( + ResBlock, + emb_channels=time_embed_dim, + dropout=dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + transformer = partial( + self.transformers[transformer_type], + context_dim=context_dim, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, + disable_self_attn=disable_self_attentions, + disable_cross_attn=disable_cross_attentions, + ) + + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + if isinstance(self.num_classes, int): + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + elif self.num_classes == "continuous": + print("setting up linear c_adm embedding layer") + self.label_emb = nn.Linear(1, time_embed_dim) + elif self.num_classes == "sequential": + assert adm_in_channels is not None + self.label_emb = nn.Sequential( + nn.Sequential( + linear(adm_in_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + ) + else: + raise ValueError() + + self.input_blocks = nn.ModuleList([ + TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1)) + ]) + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers = [resblock(ch, out_channels=mult * model_channels)] + ch = 
mult * model_channels + if ds in attention_resolutions: + if num_head_channels > -1: + current_num_heads = ch // num_head_channels + current_head_dim = num_head_channels + else: + current_num_heads = num_heads + current_head_dim = ch // num_heads + + if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: + layers.append( + SelfTransformerBlock(ch, current_head_dim) + if not use_spatial_transformer + else transformer( + ch, current_num_heads, current_head_dim, depth=transformer_depth[level] + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append(TimestepEmbedSequential( + resblock(ch, out_channels=out_ch, down=True) if resblock_updown + else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch) + )) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + + if num_head_channels > -1: + current_num_heads = ch // num_head_channels + current_head_dim = num_head_channels + else: + current_num_heads = num_heads + current_head_dim = ch // num_heads + self.middle_block = TimestepEmbedSequential( + resblock(ch), + SelfTransformerBlock(ch, current_head_dim) if not use_spatial_transformer + else transformer(ch, current_num_heads, current_head_dim, depth=transformer_depth_middle), + resblock(ch), + ) + + self.output_blocks = nn.ModuleList([]) + self.map_modules = nn.ModuleList([]) + self.warp_modules = nn.ModuleList([]) + self.style_modules = nn.ModuleList([]) + + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(self.num_res_blocks[level] + 1): + ich = input_block_chans.pop() + layers = [resblock(ch + ich, out_channels=model_channels * mult)] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels > -1: + current_num_heads = ch // num_head_channels + current_head_dim = num_head_channels + else: + current_num_heads = num_heads + current_head_dim = ch // num_heads + + if not 
exists(num_attention_blocks) or i < num_attention_blocks[level]: + layers.append( + SelfTransformerBlock(ch, current_head_dim) if not use_spatial_transformer + else transformer( + ch, current_num_heads, current_head_dim, depth=transformer_depth[level] + ) + ) + if level and i == self.num_res_blocks[level]: + out_ch = ch + layers.append( + resblock(ch, up=True) if resblock_updown else Upsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + + if map_module: + self.map_modules.append( + SelfTransformerBlock(ich) + ) + + if warp_module: + self.warp_modules.append( + SelfTransformerBlock(ich) + ) + + if style_modulation: + self.style_modules.append(nn.ModuleList([ + nn.LayerNorm(ch*2), + nn.Linear(time_embed_dim, ch*2), + zero_module(nn.Linear(ch*2, ch*2)) + ])) + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + + def forward(self, x, timesteps=None, y=None, *args, **kwargs): + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype) + emb = self.time_embed(t_emb) + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y.to(self.dtype)) + + h = self._forward(x, emb, *args, **kwargs) + return self.out(h).to(x.dtype) + + def _forward(self, x, emb, control=None, context=None, mask=None, **additional_context): + hs = [] + h = x.to(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context, mask, **additional_context) + hs.append(h) + + h = self.middle_block(h, emb, context, mask, **additional_context) + + for module in self.output_blocks: + h = torch.cat([h, hs.pop()], dim=1) + h = module(h, emb, context, mask, **additional_context) + return h + + +class 
DualCondUNet(UNetModel): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.hint_encoder_index = [0, 3, 6, 9, 11] + + def _forward(self, x, emb, control=None, context=None, mask=None, **additional_context): + h = x.to(self.dtype) + hs = [] + + control_iter = iter(control) + for idx, module in enumerate(self.input_blocks): + h = module(h, emb, context, mask, **additional_context) + + if idx in self.hint_encoder_index: + h += next(control_iter) + hs.append(h) + + h = self.middle_block(h, emb, context, mask, **additional_context) + + for idx, module in enumerate(self.output_blocks): + h = torch.cat([h, hs.pop()], dim=1) + h = module(h, emb, context, mask, **additional_context) + + return h + +class OldUnet(UNetModel): + def __init__(self, c_channels, model_channels, channel_mult, *args, **kwargs): + super().__init__(channel_mult=channel_mult, model_channels=model_channels, *args, **kwargs) + """ + Semantic condition input blocks, implementation from ControlNet. 
+ Paper: Adding Conditional Control to Text-to-Image Diffusion Models + Authors: Lvmin Zhang, Anyi Rao, and Maneesh Agrawala + Code link: https://github.com/lllyasviel/ControlNet + """ + from refnet.modules.encoder import SimpleEncoder, MultiEncoder + # self.semantic_input_blocks = SimpleEncoder(c_channels, model_channels) + self.semantic_input_blocks = MultiEncoder(c_channels, model_channels, channel_mult) + self.hint_encoder_index = [0, 3, 6, 9, 11] + + def forward(self, x, timesteps=None, control=None, context=None, y=None, **kwargs): + concat = control[0].to(self.dtype) + context = context.to(self.dtype) + + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb).to(self.dtype) + + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + h = x.to(self.dtype) + hints = self.semantic_input_blocks(concat, emb, context) + + for idx, module in enumerate(self.input_blocks): + h = module(h, emb, context) + if idx in self.hint_encoder_index: + h += hints.pop(0) + + hs.append(h) + + h = self.middle_block(h, emb, context) + + for module in self.output_blocks: + h = torch.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.to(x.dtype) + return self.out(h) + + +class UNetEncoder(nn.Module): + transformers = { + "vanilla": SpatialTransformer, + } + + def __init__( + self, + in_channels, + model_channels, + num_res_blocks, + attention_resolutions, + dropout = 0, + channel_mult = (1, 2, 4, 8), + conv_resample = True, + dims = 2, + num_classes = None, + use_checkpoint = False, + num_heads = -1, + num_head_channels = -1, + use_scale_shift_norm = False, + resblock_updown = False, + use_spatial_transformer = False, # custom transformer support + transformer_depth = 1, # custom transformer support + context_dim = None, # custom 
transformer support + disable_self_attentions = None, + disable_cross_attentions = None, + num_attention_blocks = None, + use_linear_in_transformer = False, + adm_in_channels = None, + transformer_type = "vanilla", + style_modulation = False, + ): + super().__init__() + if use_spatial_transformer: + assert exists( + context_dim) or disable_cross_attentions, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + assert transformer_type in self.transformers.keys(), f'Assigned transformer is not implemented.. Choices: {self.transformers.keys()}' + from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + self.in_channels = in_channels + self.model_channels = model_channels + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError("provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult") + self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: + # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all( + map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) + print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. 
" + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set.") + + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = torch.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.style_modulation = style_modulation + + if isinstance(transformer_depth, int): + transformer_depth = len(channel_mult) * [transformer_depth] + + time_embed_dim = model_channels * 4 + + resblock = partial( + ResBlock, + emb_channels=time_embed_dim, + dropout=dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + + transformer = partial( + self.transformers[transformer_type], + context_dim=context_dim, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, + disable_self_attn=disable_self_attentions, + disable_cross_attn=disable_cross_attentions, + ) + + zero_conv = partial(nn.Conv2d, kernel_size=1, stride=1, padding=0) + + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + if isinstance(self.num_classes, int): + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + elif self.num_classes == "continuous": + print("setting up linear c_adm embedding layer") + self.label_emb = nn.Linear(1, time_embed_dim) + elif self.num_classes == "sequential": + assert adm_in_channels is not None + self.label_emb = nn.Sequential( + nn.Sequential( + linear(adm_in_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + ) + else: + raise ValueError() + + self.input_blocks = nn.ModuleList( + [ + 
TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self.zero_layers = nn.ModuleList([zero_module( + nn.Linear(model_channels, model_channels * 2) if style_modulation else + zero_conv(model_channels, model_channels) + )]) + + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers = [resblock(ch, out_channels=mult * model_channels)] + ch = mult * model_channels + if ds in attention_resolutions: + num_heads = ch // num_head_channels + + if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: + layers.append( + SelfTransformerBlock(ch, num_head_channels) + if not use_spatial_transformer + else transformer( + ch, num_heads, num_head_channels, depth=transformer_depth[level] + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self.zero_layers.append(zero_module( + nn.Linear(ch, ch * 2) if style_modulation else zero_conv(ch, ch) + )) + + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append(TimestepEmbedSequential( + resblock(ch, out_channels=mult * model_channels, down=True) if resblock_updown else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + )) + self.zero_layers.append(zero_module( + nn.Linear(out_ch, min(model_channels * 8, out_ch * 4)) if style_modulation else + zero_conv(out_ch, out_ch) + )) + ch = out_ch + ds *= 2 + + def forward(self, x, timesteps = None, y = None, *args, **kwargs): + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype) + emb = self.time_embed(t_emb) + + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y.to(self.dtype)) + + hs = self._forward(x, emb, *args, **kwargs) + return hs + + def _forward(self, x, emb, context = None, 
**additional_context): + hints = [] + h = x.to(self.dtype) + + for zero_layer, module in zip(self.zero_layers, self.input_blocks): + h = module(h, emb, context, **additional_context) + + if self.style_modulation: + hint = zero_layer(h.mean(dim=[2, 3])) + else: + hint = zero_layer(h) + hint = rearrange(hint, "b c h w -> b (h w) c").contiguous() + hints.append(hint) + + hints.reverse() + return hints \ No newline at end of file diff --git a/refnet/sampling/__init__.py b/refnet/sampling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4c00b27a8059ac7a481abae9709eaebde64fd73d --- /dev/null +++ b/refnet/sampling/__init__.py @@ -0,0 +1,11 @@ +from .denoiser import CFGDenoiser, DiffuserDenoiser +from .hook import UnetHook, torch_dfs +from .tps_transformation import tps_warp +from .sampler import KDiffusionSampler, kdiffusion_sampler_list +from .scheduler import get_noise_schedulers + +def get_sampler_list(): + sampler_list = [ + "diffuser_" + k for k in DiffuserDenoiser.scheduler_types.keys() + ] + kdiffusion_sampler_list() + return sorted(sampler_list) \ No newline at end of file diff --git a/refnet/sampling/denoiser.py b/refnet/sampling/denoiser.py new file mode 100644 index 0000000000000000000000000000000000000000..9d7604772c7486e5a22e5ab326d3da6fb2bb0204 --- /dev/null +++ b/refnet/sampling/denoiser.py @@ -0,0 +1,181 @@ +import torch +import torch.nn as nn + +import inspect +import os.path as osp +from typing import Union, Optional +from tqdm import tqdm +from omegaconf import OmegaConf +from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser +from diffusers.schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + PNDMScheduler, + LMSDiscreteScheduler, +) + +def exists(v): + return v is not None + + + +class CFGDenoiser(nn.Module): + """ + Classifier free guidance denoiser. 
A wrapper for stable diffusion model (specifically for unet) + that can take a noisy picture and produce a noise-free picture using two guidances (prompts) + instead of one. Originally, the second prompt is just an empty string, but we use non-empty + negative prompt. + """ + + def __init__(self, model, device): + super().__init__() + denoiser = CompVisDenoiser if model.parameterization == "eps" else CompVisVDenoiser + self.model_wrap = denoiser(model, device=device) + + @property + def inner_model(self): + return self.model_wrap + + def forward( + self, + x, + sigma, + cond: dict, + cond_scale: Union[float, list[float]] + ): + """ + Simplify k-diffusion sampler for sketch colorizaiton. + Available for reference CFG / sketch CFG or Dual CFG + """ + if not isinstance(cond_scale, list): + if cond_scale > 1.: + repeats = 2 + else: + return self.inner_model(x, sigma, cond=cond) + else: + repeats = 3 + + x_in = torch.cat([x] * repeats) + sigma_in = torch.cat([sigma] * repeats) + x_out = self.inner_model(x_in, sigma_in, cond=cond).chunk(repeats) + + if repeats == 2: + x_cond, x_uncond = x_out[:] + return x_uncond + (x_cond - x_uncond) * cond_scale + else: + x_cond, x_uncond_0, x_uncond_1 = x_out[:] + return (x_uncond_0 + (x_cond - x_uncond_0) * cond_scale[0] + + x_uncond_1 + (x_cond - x_uncond_1) * cond_scale[1]) * 0.5 + + + + +scheduler_config_path = "configs/scheduler_cfgs" +class DiffuserDenoiser: + scheduler_types = { + "ddim": DDIMScheduler, + "dpm": DPMSolverMultistepScheduler, + "dpm_sde": DPMSolverMultistepScheduler, + "pndm": PNDMScheduler, + "lms": LMSDiscreteScheduler + } + def __init__(self, scheduler_type, prediction_type, use_karras=False): + scheduler_type = scheduler_type.replace("diffuser_", "") + assert scheduler_type in self.scheduler_types.keys(), "Selected scheduler is not implemented" + scheduler = self.scheduler_types[scheduler_type] + scheduler_config = OmegaConf.load(osp.abspath(osp.join(scheduler_config_path, scheduler_type + ".yaml"))) + if 
"use_karras_sigmas" in set(inspect.signature(scheduler).parameters.keys()): + scheduler_config.use_karras_sigmas = use_karras + self.scheduler = scheduler(prediction_type=prediction_type, **scheduler_config) + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def __call__( + self, + x, + cond, + cond_scale, + unet, + timesteps, + generator: Optional[Union[torch.Generator, list[torch.Generator]]] = None, + eta: float = 0.0, + device: str = "cuda" + ): + self.scheduler.set_timesteps(timesteps, device=device) + timesteps = self.scheduler.timesteps + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + x_start = x + x = x * self.scheduler.init_noise_sigma + inpaint_latents = cond.pop("inpaint_bg", None) + + if exists(inpaint_latents): + mask = cond.get("mask", None) + threshold = cond.pop("threshold", 0.5) + inpaint_latents = inpaint_latents[0] + assert exists(mask) + mask = mask[0] + mask = torch.where(mask > threshold, torch.ones_like(mask), torch.zeros_like(mask)) + + for i, t in enumerate(tqdm(timesteps)): + x_t = self.scheduler.scale_model_input(x, t) + + if not isinstance(cond_scale, list): + if cond_scale > 1.: + repeats = 2 + else: + repeats = 1 + else: + repeats = 3 + + x_in = torch.cat([x_t] * repeats) + x_out = unet.apply_model( + 
x_in, + t[None].expand(x_in.shape[0]), + cond=cond + ) + + if repeats == 1: + pred = x_out + + elif repeats == 2: + x_cond, x_uncond = x_out.chunk(2) + pred = x_uncond + (x_cond - x_uncond) * cond_scale + + else: + x_cond, x_uncond_0, x_uncond_1 = x_out.chunk(3) + pred = (x_uncond_0 + (x_cond - x_uncond_0) * cond_scale[0] + + x_uncond_1 + (x_cond - x_uncond_1) * cond_scale[1]) * 0.5 + + x = self.scheduler.step( + pred, t, x, **extra_step_kwargs, return_dict=False + )[0] + + if exists(inpaint_latents) and exists(mask) and i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = inpaint_latents + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, x_start, torch.tensor([noise_timestep]) + ) + x = (1 - mask) * init_latents_proper + mask * x + + return x \ No newline at end of file diff --git a/refnet/sampling/hook.py b/refnet/sampling/hook.py new file mode 100644 index 0000000000000000000000000000000000000000..940240c7973a6e3e955e34fd0a591e5cea8c2995 --- /dev/null +++ b/refnet/sampling/hook.py @@ -0,0 +1,257 @@ +import torch +import torch.nn as nn + +from refnet.modules.transformer import BasicTransformerBlock, SelfInjectedTransformerBlock +from refnet.util import checkpoint_wrapper + +""" + This implementation refers to Multi-ControlNet, thanks for the authors + Paper: Adding Conditional Control to Text-to-Image Diffusion Models + Link: https://github.com/Mikubill/sd-webui-controlnet +""" + +def exists(v): + return v is not None + +def torch_dfs(model: nn.Module): + result = [model] + for child in model.children(): + result += torch_dfs(child) + return result + +class AutoMachine(): + Read = "read" + Write = "write" + + +""" + This class controls the attentions of reference unet and denoising unet +""" +class ReferenceAttentionControl: + writer_modules = [] + reader_modules = [] + def __init__( + self, + reader_module, + writer_module, + time_embed_ch = 0, + only_decoder = True, + *args, + **kwargs + ): + 
self.time_embed_ch = time_embed_ch + self.trainable_layers = [] + self.only_decoder = only_decoder + self.hooked = False + + self.register("read", reader_module) + self.register("write", writer_module) + + if time_embed_ch > 0: + self.insert_time_emb_proj(reader_module) + + def insert_time_emb_proj(self, unet): + for module in torch_dfs(unet.output_blocks if self.only_decoder else unet): + if isinstance(module, BasicTransformerBlock): + module.time_proj = nn.Linear(self.time_embed_ch, module.dim) + self.trainable_layers.append(module.time_proj) + + def register(self, mode, unet): + @checkpoint_wrapper + def transformer_forward_write(self, x, context=None, mask=None, emb=None, **kwargs): + x_in = self.norm1(x) + x = self.attn1(x_in) + x + + if not self.disable_cross_attn: + x = self.attn2(self.norm2(x), context) + x + x = self.ff(self.norm3(x)) + x + + self.bank = x_in + return x + + @checkpoint_wrapper + def transformer_forward_read(self, x, context=None, mask=None, emb=None, **kwargs): + if exists(self.bank): + bank = self.bank + if bank.shape[0] != x.shape[0]: + bank = bank.repeat(x.shape[0], 1, 1) + if hasattr(self, "time_proj"): + bank = bank + self.time_proj(emb).unsqueeze(1) + x_in = self.norm1(x) + + x = self.attn1( + x = x_in, + context = torch.cat([x_in, bank], 1), + mask = mask, + scale_factor = self.scale_factor, + **kwargs + ) + x + + x = self.attn2( + x = self.norm2(x), + context = context, + mask = mask, + scale = self.reference_scale, + scale_factor = self.scale_factor + ) + x + + x = self.ff(self.norm3(x)) + x + else: + x = self.original_forward(x, context, mask, emb) + return x + + assert mode in ["write", "read"] + + if mode == "read": + self.hooked = True + for module in torch_dfs(unet.output_blocks if self.only_decoder else unet): + if isinstance(module, BasicTransformerBlock): + if mode == "write": + module.original_forward = module.forward + module.forward = transformer_forward_write.__get__(module, BasicTransformerBlock) + 
"""
    This class is for self-injection inside the denoising unet
"""
class UnetHook:
    """Hooks a denoising UNet so self-attention contexts computed on a noised
    reference image are written into per-block banks and then read (concatenated
    as extra attention context) during the actual generation pass.
    """

    def __init__(self):
        super().__init__()
        # Current mode of the hooked transformer blocks: Write = record contexts
        # from the reference pass, Read = consume them in the generation pass.
        self.attention_auto_machine = AutoMachine.Read

    def enhance_reference(
        self,
        model,
        ldm,
        bs,
        s,
        r,
        style_cfg=0.5,
        control_cfg=0,
        gr_indice=None,
        injection=False,
        start_step=0,
    ):
        """Install the reference-injection hooks on ``model.diffusion_model``.

        Args:
            model: wrapper holding the UNet as ``model.diffusion_model``.
            ldm: diffusion model providing ``add_noise``, ``control_encoder``
                and ``num_timesteps``.
            bs: batch size used to repeat the encoded control features.
            s: control (sketch) input fed through ``ldm.control_encoder``.
            r: reference latent to be noised and passed through the UNet.
            style_cfg: style fidelity in [0, 1]; blends injected vs. plain
                self-attention results.
            control_cfg: multiplier applied to the encoded control features.
            gr_indice: batch indices treated as "unconditional" for style CFG.
            injection: stored flag — presumably consumed elsewhere; not used
                inside this method itself.
            start_step: fraction of sampling progress after which injection
                becomes active (0 = from the very first step).
        """
        def forward(self, x, t, control, context, **kwargs):
            # Progress = 1 - t/(T-1) grows as sampling proceeds; inject only
            # once it passes ``start_step``.
            if 1 - t[0] / (ldm.num_timesteps - 1) >= outer.start_step:
                # Write: run the UNet on the noised reference so every hooked
                # transformer block records its self-attention context.
                outer.attention_auto_machine = AutoMachine.Write

                rx = ldm.add_noise(outer.r.cpu(), torch.round(t.float()).long().cpu()).cuda().to(x.dtype)
                self.original_forward(rx, t, control=outer.s, context=context, **kwargs)

                # Read: the subsequent real pass consumes the recorded banks.
                outer.attention_auto_machine = AutoMachine.Read
            return self.original_forward(x, t, control=control, context=context, **kwargs)

        def hacked_basic_transformer_inner_forward(self, x, context=None, mask=None, emb=None, **kwargs):
            x_norm1 = self.norm1(x)
            self_attn1 = None
            if self.disable_self_attn:
                # Do not use self-attention
                self_attn1 = self.attn1(x_norm1, context=context, **kwargs)

            else:
                # Use self-attention
                self_attention_context = x_norm1
                if outer.attention_auto_machine == AutoMachine.Write:
                    # Record this block's context for the later Read pass.
                    self.bank.append(self_attention_context.detach().clone())
                    self.style_cfgs.append(outer.current_style_fidelity)
                if outer.attention_auto_machine == AutoMachine.Read:
                    if len(self.bank) > 0:
                        style_cfg = sum(self.style_cfgs) / float(len(self.style_cfgs))
                        # "Unconditional" result: attend over own context plus
                        # everything recorded in the bank.
                        self_attn1_uc = self.attn1(
                            x_norm1,
                            context=torch.cat([self_attention_context] + self.bank, dim=1),
                            **kwargs
                        )
                        self_attn1_c = self_attn1_uc.clone()
                        # For the designated CFG indices, recompute WITHOUT the
                        # bank so style fidelity can interpolate between the two.
                        if len(outer.current_uc_indices) > 0 and style_cfg > 1e-5:
                            self_attn1_c[outer.current_uc_indices] = self.attn1(
                                x_norm1[outer.current_uc_indices],
                                context=self_attention_context[outer.current_uc_indices],
                                **kwargs
                            )
                        self_attn1 = style_cfg * self_attn1_c + (1.0 - style_cfg) * self_attn1_uc
                    # Banks are consumed once per step.
                    self.bank = []
                    self.style_cfgs = []
                if self_attn1 is None:
                    # Fallback: plain self-attention (nothing banked yet).
                    self_attn1 = self.attn1(x_norm1, context=self_attention_context)

            x = self_attn1.to(x.dtype) + x
            x = self.attn2(self.norm2(x), context, mask, self.reference_scale, self.scale_factor, **kwargs) + x
            x = self.ff(self.norm3(x)) + x
            return x

        # Pre-encode the control input once; repeated to batch size and scaled.
        self.s = [s.repeat(bs, 1, 1, 1) * control_cfg for s in ldm.control_encoder(s)]
        self.r = r
        self.injection = injection
        self.start_step = start_step
        self.current_uc_indices = gr_indice
        self.current_style_fidelity = style_cfg

        # ``outer`` is closed over by the nested forwards defined above; it is
        # bound before any of them can run (they fire only during sampling).
        outer = self
        model = model.diffusion_model
        model.original_forward = model.forward
        # TODO: change the class name to target
        model.forward = forward.__get__(model, model.__class__)
        all_modules = torch_dfs(model)

        for module in all_modules:
            if isinstance(module, BasicTransformerBlock):
                module._unet_hook_original_forward = module.forward
                module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
                module.bank = []
                module.style_cfgs = []


    def restore(self, model):
        """Undo everything ``enhance_reference`` installed on the UNet."""
        model = model.diffusion_model
        if hasattr(model, "original_forward"):
            model.forward = model.original_forward
            del model.original_forward

        all_modules = torch_dfs(model)
        for module in all_modules:
            if isinstance(module, BasicTransformerBlock):
                if hasattr(module, "_unet_hook_original_forward"):
                    module.forward = module._unet_hook_original_forward
                    del module._unet_hook_original_forward
                if hasattr(module, "bank"):
                    module.bank = None
                if hasattr(module, "style_cfgs"):
                    del module.style_cfgs
def local_manipulate_step(clip, v, t, target_scale, a=None, c=None, enhance=False, thresholds=()):
    """Apply one localized, text-guided manipulation to CLIP visual tokens.

    Args:
        clip: encoder exposing ``calculate_scale`` and ``encode_text``.
        v: visual tokens of shape (b, n+1, c); ``v[:, 0]`` is the CLS token.
        t: target text embedding of shape (b, 1, c).
        target_scale: desired CLS-token/target similarity after the edit.
        a: anchor prompt string, or None/"none" to edit without an anchor.
        c: control text embedding (or the literal string "everything" to edit
           all tokens uniformly) restricting where the edit applies.
        enhance: if True, measure the shift relative to the anchor instead of
           the current target similarity, and add the per-token anchor map.
        thresholds: four thresholds forwarded to ``compute_pwv``.

    Returns:
        Manipulated tokens, same shape as ``v``.
    """
    cls_token = v[:, 0].unsqueeze(1)
    v = v[:, 1:]

    # Current global similarity between the image (CLS token) and the target.
    cur_target_scale = clip.calculate_scale(cls_token, t)

    if a is not None and a != "none":
        # Encode the anchor prompt once per batch element.
        a = clip.encode_text([a] * v.shape[0])
        anchor_scale = clip.calculate_scale(cls_token, a)
        dscale = target_scale - cur_target_scale if not enhance else target_scale - anchor_scale

        c_map = clip.calculate_scale(v, c)
        a_map = clip.calculate_scale(v, a)
        # "everything" disables the per-token mask: every token gets ``dscale``.
        pwm = compute_pwv(c_map, dscale, thresholds=thresholds) if c != "everything" else dscale
        base = 1 if enhance else 0
        v = v + (pwm + base * a_map) * (t - a)
    else:
        dscale = target_scale - cur_target_scale
        c_map = clip.calculate_scale(v, c)
        pwm = compute_pwv(c_map, dscale, thresholds=thresholds) if c != "everything" else dscale
        v = v + pwm * t
    return torch.cat([cls_token, v], dim=1)


def local_manipulate(clip, v, targets, target_scales, anchors, controls, enhances=(), thresholds_list=()):
    """Sequentially apply several local manipulations to visual tokens.

    Args:
        clip: encoder exposing ``calculate_scale`` and ``encode_text``.
        v: visual tokens of shape (b, n+1, c).
        targets / controls: prompt-string lists, encoded together in one batch.
        target_scales, anchors, enhances, thresholds_list: per-manipulation
            settings, zipped with the encoded prompts.
    """
    controls, targets = clip.encode_text(controls + targets).chunk(2)
    for t, a, c, s_t, enhance, thresholds in zip(targets, anchors, controls, target_scales, enhances, thresholds_list):
        v = local_manipulate_step(clip, v, t, s_t, a, c, enhance, thresholds)
    return v


def global_manipulate_step(clip, v, t, target_scale, a=None, enhance=False):
    """Apply one global (whole-embedding) manipulation step to ``v``."""
    if a is not None and a != "none":
        a = clip.encode_text([a] * v.shape[0])
        if enhance:
            # NOTE(review): this branch subtracts the anchor direction and
            # ignores ``t``/``target_scale`` — looks intentional, but confirm
            # against the caller before relying on it.
            s_a = clip.calculate_scale(v, a)
            v = v - s_a * a
        else:
            v = v + target_scale * (t - a)
        return v
    if enhance:
        v = v + target_scale * t
    else:
        cur_target_scale = clip.calculate_scale(v, t)
        v = v + (target_scale - cur_target_scale) * t
    return v


def global_manipulate(clip, v, targets, target_scales, anchors, enhances):
    """Sequentially apply global manipulations for each encoded target."""
    targets = clip.encode_text(targets)
    for t, a, s_t, enhance in zip(targets, anchors, target_scales, enhances):
        v = global_manipulate_step(clip, v, t, s_t, a, enhance)
    return v


def assign_heatmap(s: torch.Tensor, threshold: float):
    """Binarize similarity scores of shape (b, n, 1) against ``threshold``.

    Scores are min-max normalized per batch; entries below the threshold map
    to 0, the rest to 0.25 (the overlay intensity used by the UI).
    """
    maxm = s.max(dim=1, keepdim=True).values
    minm = s.min(dim=1, keepdim=True).values
    d = maxm - minm
    return torch.where((s - minm) / d < threshold, torch.zeros_like(s), torch.ones_like(s) * 0.25)


def get_heatmaps(model, reference, height, width, vis_c, ts0, ts1, ts2, ts3,
                 controls, targets, anchors, thresholds_list, target_scales, enhances):
    """Render four threshold heatmaps of token similarity against ``vis_c``.

    Encodes the reference image, optionally applies the queued local
    manipulations, then visualizes each token's similarity to the ``vis_c``
    prompt as uint8 heatmaps binarized at thresholds ``ts0``-``ts3``.
    """
    model.low_vram_shift("cond")
    clip = model.cond_stage_model

    v = clip.encode(reference, "full")
    if len(targets) > 0:
        controls, targets = clip.encode_text(controls + targets).chunk(2)
        inputs_iter = zip(controls, targets, anchors, target_scales, thresholds_list, enhances)
        for c, t, a, target_scale, thresholds, enhance in inputs_iter:
            # update image tokens
            v = local_manipulate_step(clip, v, t, target_scale, a, c, enhance, thresholds)
    # Local tokens (minus CLS) form a square grid of patches.
    token_length = v.shape[1] - 1
    grid_num = int(token_length ** 0.5)
    vis_c = clip.encode_text([vis_c])
    local_v = v[:, 1:]
    scale = clip.calculate_scale(local_v, vis_c)
    scale = scale.permute(0, 2, 1).view(1, 1, grid_num, grid_num)
    scale = F.interpolate(scale, size=(height, width), mode="bicubic").squeeze(0).view(1, height * width)

    # calculate heatmaps
    heatmaps = []
    for threshold in (ts0, ts1, ts2, ts3):
        heatmap = assign_heatmap(scale, threshold=threshold)
        heatmap = heatmap.view(1, height, width).permute(1, 2, 0).cpu().numpy()
        heatmap = (heatmap * 255.).astype(np.uint8)
        heatmaps.append(heatmap)
    return heatmaps
# Extra keyword arguments each k-diffusion sampling function accepts, pulled
# from the sampler instance when present in the function's signature.
sampler_extra_params = {
    'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
    'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
    'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
    'sample_dpm_fast': ['s_noise'],
    'sample_dpm_2_ancestral': ['s_noise'],
    'sample_dpmpp_2s_ancestral': ['s_noise'],
    'sample_dpmpp_sde': ['s_noise'],
    'sample_dpmpp_2m_sde': ['s_noise'],
    'sample_dpmpp_3m_sde': ['s_noise'],
}

def kdiffusion_sampler_list():
    """Display labels for every registered k-diffusion sampler."""
    return [k.label for k in samplers_k_diffusion]


# Lookup tables: sampler label -> Sampler config, scheduler name -> function.
k_diffusion_samplers_map = {x.label: x for x in samplers_k_diffusion}
k_diffusion_scheduler = {x.name: x.function for x in schedulers}

def exists(v):
    """True when ``v`` is not None."""
    return v is not None


class KDiffusionSampler:
    """Thin adapter that drives a k-diffusion sampling function over a
    CFG-wrapped denoiser, handling per-sampler extra parameters, noise
    schedules and brownian noise samplers.
    """

    def __init__(self, sampler, scheduler, sd, device):
        """Args:
            sampler: label into ``k_diffusion_samplers_map``.
            scheduler: noise-scheduler name/label (or 'Automatic').
            sd: diffusion model; wrapped in ``CFGDenoiser``. May carry
                optional ``sigma_min``/``sigma_max`` overrides.
            device: torch device used when building sigma schedules.
        """
        self.config = k_diffusion_samplers_map[sampler]
        funcname = self.config.funcname

        self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, funcname)
        self.scheduler_name = scheduler
        self.sd = CFGDenoiser(sd, device)
        self.model_wrap = self.sd.model_wrap
        self.device = device

        self.s_min_uncond = None
        self.s_churn = 0.0
        self.s_tmin = 0.0
        self.s_tmax = float('inf')
        self.s_noise = 1.0

        self.eta_option_field = 'eta_ancestral'
        self.eta_infotext_field = 'Eta'
        self.eta_default = 1.0
        self.eta = None

        self.extra_params = []

        # Optional per-model sigma overrides clamp the wrapped schedule's
        # endpoints. NOTE(review): writes sigma_max into sigmas[-1] and
        # sigma_min into sigmas[0] — confirm the wrapper stores sigmas in
        # ascending order so this is the intended orientation.
        if exists(sd.sigma_max) and exists(sd.sigma_min):
            self.model_wrap.sigmas[-1] = sd.sigma_max
            self.model_wrap.sigmas[0] = sd.sigma_min

    def initialize(self):
        """Collect the keyword arguments the sampling function accepts.

        Pulls eta and the churn/noise knobs from ``defaults``, but only
        forwards those actually present in ``self.func``'s signature.
        Returns the kwargs dict for the sampling call.
        """
        self.eta = getattr(defaults, self.eta_option_field, 0.0)

        extra_params_kwargs = {}
        for param_name in self.extra_params:
            if param_name in inspect.signature(self.func).parameters:
                extra_params_kwargs[param_name] = getattr(self, param_name)

        if 'eta' in inspect.signature(self.func).parameters:
            extra_params_kwargs['eta'] = self.eta

        if len(self.extra_params) > 0:
            s_churn = getattr(defaults, 's_churn', self.s_churn)
            s_tmin = getattr(defaults, 's_tmin', self.s_tmin)
            s_tmax = getattr(defaults, 's_tmax', self.s_tmax) or self.s_tmax  # 0 = inf
            s_noise = getattr(defaults, 's_noise', self.s_noise)

            # Only override (and remember) values that differ from the
            # currently stored ones.
            if 's_churn' in extra_params_kwargs and s_churn != self.s_churn:
                extra_params_kwargs['s_churn'] = s_churn
                self.s_churn = s_churn
            if 's_tmin' in extra_params_kwargs and s_tmin != self.s_tmin:
                extra_params_kwargs['s_tmin'] = s_tmin
                self.s_tmin = s_tmin
            if 's_tmax' in extra_params_kwargs and s_tmax != self.s_tmax:
                extra_params_kwargs['s_tmax'] = s_tmax
                self.s_tmax = s_tmax
            if 's_noise' in extra_params_kwargs and s_noise != self.s_noise:
                extra_params_kwargs['s_noise'] = s_noise
                self.s_noise = s_noise

        return extra_params_kwargs

    def create_noise_sampler(self, x, sigmas, seed):
        """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
        from k_diffusion.sampling import BrownianTreeNoiseSampler
        sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
        return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed)

    def get_sigmas(self, steps, sigmas_min=None, sigmas_max=None):
        """Build the sigma schedule for ``steps`` sampling steps.

        Honors the sampler's ``discard_next_to_last_sigma`` option (adds one
        step, then drops the next-to-last sigma) and resolves 'Automatic' to
        the sampler's preferred scheduler.
        """
        discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)

        steps += 1 if discard_next_to_last_sigma else 0

        if self.scheduler_name == 'Automatic':
            # NOTE(review): this permanently rewrites self.scheduler_name on
            # the first call — confirm the instance is not reused with a
            # different intent.
            self.scheduler_name = self.config.options.get('scheduler', None)

        scheduler = schedulers_map.get(self.scheduler_name)
        sigma_min = default(sigmas_min, self.model_wrap.sigma_min)
        sigma_max = default(sigmas_max, self.model_wrap.sigma_max)

        if scheduler is None or scheduler.function is None:
            # Fall back to the model's own schedule.
            sigmas = self.model_wrap.get_sigmas(steps)
        else:
            sigmas_kwargs = {'sigma_min': sigma_min, 'sigma_max': sigma_max}

            if scheduler.need_inner_model:
                sigmas_kwargs['inner_model'] = self.model_wrap

            sigmas = scheduler.function(n=steps, **sigmas_kwargs, device=self.device)

        if discard_next_to_last_sigma:
            sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])

        return sigmas


    def __call__(self, x, sigmas, sampler_extra_args, seed, deterministic, steps=None):
        """Run the sampling function on initial noise ``x``.

        Args:
            x: initial latent noise; scaled by the first sigma before sampling.
            sigmas: precomputed sigma schedule.
            sampler_extra_args: dict forwarded as ``extra_args`` to the
                k-diffusion sampler (conditioning etc.).
            seed: RNG seed for the brownian noise sampler.
            deterministic: when True, use a seeded BrownianTreeNoiseSampler.
            steps: step count, forwarded only if the function accepts ``n``.
        """
        x = x * sigmas[0]

        extra_params_kwargs = self.initialize()
        parameters = inspect.signature(self.func).parameters

        if 'n' in parameters:
            extra_params_kwargs['n'] = steps

        if 'sigma_min' in parameters:
            extra_params_kwargs['sigma_min'] = sigmas[sigmas > 0].min()
            extra_params_kwargs['sigma_max'] = sigmas.max()

        if 'sigmas' in parameters:
            extra_params_kwargs['sigmas'] = sigmas

        if self.config.options.get('brownian_noise', False):
            noise_sampler = self.create_noise_sampler(x, sigmas, seed) if deterministic else None
            extra_params_kwargs['noise_sampler'] = noise_sampler

        if self.config.options.get('solver_type', None) == 'heun':
            extra_params_kwargs['solver_type'] = 'heun'

        return self.func(self.sd, x, extra_args=sampler_extra_args, disable=False, **extra_params_kwargs)
def uniform(n, sigma_min, sigma_max, inner_model, device):
    """Uniform schedule: defer entirely to the wrapped model's native sigmas."""
    return inner_model.get_sigmas(n)


def sgm_uniform(n, sigma_min, sigma_max, inner_model, device):
    """SGM-style uniform schedule: step uniformly in t-space, then append 0."""
    t_start = inner_model.sigma_to_t(torch.tensor(sigma_max))
    t_end = inner_model.sigma_to_t(torch.tensor(sigma_min))
    timesteps = torch.linspace(t_start, t_end, n + 1)[:-1]
    values = [inner_model.t_to_sigma(step) for step in timesteps]
    values.append(0.0)
    return torch.FloatTensor(values).to(device)

# Registry of every available noise scheduler.
schedulers = [
    Scheduler('automatic', 'Automatic', None),
    Scheduler('uniform', 'Uniform', uniform, need_inner_model=True),
    Scheduler('karras', 'Karras', k_diffusion.sampling.get_sigmas_karras, default_rho=7.0),
    Scheduler('exponential', 'Exponential', k_diffusion.sampling.get_sigmas_exponential),
    Scheduler('polyexponential', 'Polyexponential', k_diffusion.sampling.get_sigmas_polyexponential, default_rho=1.0),
    Scheduler('sgm_uniform', 'SGM Uniform', sgm_uniform, need_inner_model=True, aliases=["SGMUniform"]),
]

def get_noise_schedulers():
    """Human-readable labels for every registered noise scheduler."""
    return [entry.label for entry in schedulers]

# Resolve a scheduler by either its internal name or its display label.
schedulers_map = {key: entry for entry in schedulers for key in (entry.name, entry.label)}
def _tps_warp(im, pts_before, pts_after, normalize=True):
    '''
    Deforms image according to movement of pts_before and pts_after

    Solves the thin-plate-spline system L @ x = V (one surface for dx, one
    for dy), evaluates the resulting displacement at every pixel and resamples
    the image with grid_sample.

    Args)
        im          torch.Tensor object of size NxCxHxW
        pts_before  torch.Tensor object of size NxTx2 (T is # control pts)
        pts_after   torch.Tensor object of size NxTx2 (T is # control pts)
        normalize   when True, control points are given in pixel coordinates
                    and are rescaled here to the [-1, 1] range grid_sample uses
    '''
    # check input requirements
    assert (4 == im.dim())
    assert (3 == pts_after.dim())
    assert (3 == pts_before.dim())
    N = im.size()[0]
    assert (N == pts_after.size()[0] and N == pts_before.size()[0])
    assert (2 == pts_after.size()[2] and 2 == pts_before.size()[2])
    T = pts_after.size()[1]
    assert (T == pts_before.size()[1])
    H = im.size()[2]
    W = im.size()[3]

    if normalize:
        # Map pixel coordinates to [-1, 1]; clone first so callers' tensors
        # are not mutated in place.
        pts_after = pts_after.clone()
        pts_after[:, :, 0] /= 0.5 * W
        pts_after[:, :, 1] /= 0.5 * H
        pts_after -= 1
        pts_before = pts_before.clone()
        pts_before[:, :, 0] /= 0.5 * W
        pts_before[:, :, 1] /= 0.5 * H
        pts_before -= 1

    def construct_P():
        '''
        Consturcts matrix P of size NxTx3 where
            P[n,i,0] := 1
            P[n,i,1:] := pts_after[n]
        '''
        # Create matrix P with same configuration as 'pts_after'
        P = pts_after.new_zeros((N, T, 3))
        P[:, :, 0] = 1
        P[:, :, 1:] = pts_after

        return P

    def calc_U(pt1, pt2):
        '''
        Calculate distance U between pt1 and pt2

            U(r) := r**2 * log(r)
        where
            r := |pt1 - pt2|_2

        Args)
            pt1     torch.Tensor object, last dim is always 2
            pt2     torch.Tensor object, last dim is always 2
        '''
        assert (2 == pt1.size()[-1])
        assert (2 == pt2.size()[-1])

        diff = pt1 - pt2
        sq_diff = diff ** 2
        sq_diff_sum = sq_diff.sum(-1)
        r = sq_diff_sum.sqrt()

        # Adds 1e-6 for numerical stability
        return (r ** 2) * torch.log(r + 1e-6)

    def construct_K():
        '''
        Consturcts matrix K of size NxTxT where
            K[n,i,j] := U(|pts_after[n,i] - pts_after[n,j]|_2)
        '''

        # Assuming the number of control points are small enough,
        # We just use for-loop for easy-to-read code

        # Create matrix K with same configuration as 'pts_after'
        K = pts_after.new_zeros((N, T, T))
        for i in range(T):
            for j in range(T):
                K[:, i, j] = calc_U(pts_after[:, i, :], pts_after[:, j, :])

        return K

    def construct_L():
        '''
        Consturcts matrix L of size Nx(T+3)x(T+3) where
            L[n] = [[ K[n]     P[n] ]]
                   [[ P[n]^T   0    ]]
        '''
        P = construct_P()
        K = construct_K()

        # Create matrix L with same configuration as 'K'
        L = K.new_zeros((N, T + 3, T + 3))

        # Fill L matrix
        L[:, :T, :T] = K
        L[:, :T, T:(T + 3)] = P
        L[:, T:(T + 3), :T] = P.transpose(1, 2)

        return L

    def construct_uv_grid():
        '''
        Returns H x W x 2 tensor uv with UV coordinate as its elements
            uv[:,:,0] is H x W grid of x values
            uv[:,:,1] is H x W grid of y values
        '''
        u_range = torch.arange(
            start=-1.0, end=1.0, step=2.0 / W, device=im.device)
        assert (W == u_range.size()[0])
        u = u_range.new_zeros((H, W))
        u[:] = u_range

        v_range = torch.arange(
            start=-1.0, end=1.0, step=2.0 / H, device=im.device)
        assert (H == v_range.size()[0])
        vt = v_range.new_zeros((W, H))
        vt[:] = v_range
        v = vt.transpose(0, 1)

        return torch.stack([u, v], dim=2)

    L = construct_L()
    VT = pts_before.new_zeros((N, T + 3, 2))
    # Use delta x and delta y as known heights of the surface
    VT[:, :T, :] = pts_before - pts_after

    # Solve Lx = VT
    # x is of shape (N, T+3, 2)
    # x[:,:,0] represents surface parameters for dx surface
    #   (dx values as surface height (z))
    # x[:,:,1] represents surface parameters for dy surface
    #   (dy values as surface height (z))
    x = torch.linalg.solve(L, VT)

    uv = construct_uv_grid()
    uv_batch = uv.repeat((N, 1, 1, 1))

    def calc_dxdy():
        '''
        Calculate surface height for each uv coordinate

        Returns NxHxWx2 tensor
        '''

        # control points of size NxTxHxWx2
        cp = uv.new_zeros((H, W, N, T, 2))
        cp[:, :, :] = pts_after
        cp = cp.permute([2, 3, 0, 1, 4])

        U = calc_U(uv, cp)  # U value matrix of size NxTxHxW
        w, a = x[:, :T, :], x[:, T:, :]  # w is of size NxTx2, a is of size Nx3x2
        w_x, w_y = w[:, :, 0], w[:, :, 1]  # NxT each
        a_x, a_y = a[:, :, 0], a[:, :, 1]  # Nx3 each
        # TPS surface: affine part (a) plus weighted radial basis part (w).
        dx = (
                a_x[:, 0].repeat((H, W, 1)).permute(2, 0, 1) +
                torch.einsum('nhwd,nd->nhw', uv_batch, a_x[:, 1:]) +
                torch.einsum('nthw,nt->nhw', U, w_x))  # dx values of NxHxW
        dy = (
                a_y[:, 0].repeat((H, W, 1)).permute(2, 0, 1) +
                torch.einsum('nhwd,nd->nhw', uv_batch, a_y[:, 1:]) +
                torch.einsum('nthw,nt->nhw', U, w_y))  # dy values of NxHxW

        return torch.stack([dx, dy], dim=3)

    dxdy = calc_dxdy()
    flow_field = uv + dxdy

    # NOTE(review): grid_sample is called without align_corners; recent torch
    # defaults to align_corners=False and warns — confirm this matches the
    # coordinate convention assumed by construct_uv_grid.
    return F.grid_sample(im, flow_field.to(im.dtype))
def exists(x):
    """True when ``x`` is not None."""
    return x is not None

def append_dims(x, target_dims) -> torch.Tensor:
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
        )
    return x[(...,) + (None,) * dims_to_append]


def default(val, d):
    """Return ``val`` if it exists, otherwise ``d`` (called if it is a function)."""
    if exists(val):
        return val
    return d() if isfunction(d) else d

def expand_to_batch_size(x, bs):
    """Repeat a tensor (or each tensor in a list) along dim 0 to batch size ``bs``."""
    if isinstance(x, list):
        x = [xi.repeat(bs, *([1] * (len(xi.shape) - 1))) for xi in x]
    else:
        x = x.repeat(bs, *([1] * (len(x.shape) - 1)))
    return x


def get_obj_from_str(string, reload=False):
    """Import and return the object named by a dotted path ``module.Class``."""
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)

def instantiate_from_config(config):
    """Instantiate ``config['target']`` with ``config['params']`` kwargs.

    The sentinel strings '__is_first_stage__' / '__is_unconditional__' yield
    None; any other config without a 'target' key raises KeyError.
    """
    if "target" not in config:
        if config == '__is_first_stage__':
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


def scaled_resize(x: torch.Tensor, scale_factor, interpolation_mode="bicubic"):
    """Resize ``x`` by a scale factor using the given interpolation mode."""
    return F.interpolate(x, scale_factor=scale_factor, mode=interpolation_mode)

def get_crop_scale(h, w, bgh, bgw):
    """Crop fractions (ch, cw) that make a (h, w) image match a (bgh, bgw)
    background's aspect ratio; the wider dimension keeps fraction 1.0."""
    gen_aspect = w / h
    bg_aspect = bgw / bgh
    if gen_aspect > bg_aspect:
        cw = 1.0
        ch = (h / w) * (bgw / bgh)
    else:
        ch = 1.0
        cw = (w / h) * (bgh / bgw)
    return ch, cw

def warp_resize(x: torch.Tensor, target_size, interpolation_mode="bicubic"):
    """Resize a 4D tensor to an explicit ``target_size`` (may change aspect)."""
    assert len(x.shape) == 4
    return F.interpolate(x, size=target_size, mode=interpolation_mode)

def resize_and_crop(x: torch.Tensor, ch, cw, th, tw):
    """Crop the top-left (ch*h, cw*w) region of ``x`` and resize to (th, tw)."""
    b, c, h, w = x.shape
    return tf.resized_crop(x, 0, 0, int(ch * h), int(cw * w), size=[th, tw])


def fitting_weights(model, sd):
    """Adapt a checkpoint ``sd`` to ``model``'s (possibly larger) shapes.

    Parameters whose first one or two axes differ are filled by tiling the old
    values modulo the old shape; tiled columns are down-weighted by how often
    each old column is reused, so the expected activation magnitude is kept.
    Returns the modified state dict.
    """
    n_params = len([name for name, _ in
                    itertools.chain(model.named_parameters(),
                                    model.named_buffers())])
    for name, param in tqdm(
            itertools.chain(model.named_parameters(),
                            model.named_buffers()),
            desc="Fitting old weights to new weights",
            total=n_params
    ):
        if not name in sd:
            continue
        old_shape = sd[name].shape
        new_shape = param.shape
        assert len(old_shape) == len(new_shape)
        if len(new_shape) > 2:
            # we only modify first two axes
            assert new_shape[2:] == old_shape[2:]
        # assumes first axis corresponds to output dim
        if not new_shape == old_shape:
            old_param = sd[name]
            device = old_param.device
            if len(new_shape) == 1:
                # Vectorized 1D case: tile rows modulo the old length.
                new_param = old_param[torch.arange(new_shape[0], device=device) % old_shape[0]]
            elif len(new_shape) >= 2:
                # Vectorized 2D case: tile rows and columns modulo old shape.
                i_indices = torch.arange(new_shape[0], device=device)[:, None] % old_shape[0]
                j_indices = torch.arange(new_shape[1], device=device)[None, :] % old_shape[1]

                # Use advanced indexing to extract all values at once
                new_param = old_param[i_indices, j_indices]

                # Count how many times each old column is used
                n_used_old = torch.bincount(
                    torch.arange(new_shape[1], device=device) % old_shape[1],
                    minlength=old_shape[1]
                )

                # Map to new shape
                n_used_new = n_used_old[torch.arange(new_shape[1], device=device) % old_shape[1]]

                # Reshape for broadcasting
                n_used_new = n_used_new.reshape(1, new_shape[1])
                while len(n_used_new.shape) < len(new_shape):
                    n_used_new = n_used_new.unsqueeze(-1)

                # Normalize so reused columns do not inflate magnitudes.
                new_param = new_param / n_used_new

            sd[name] = new_param
    return sd


def count_params(model, verbose=False):
    """Total parameter count of ``model``; optionally print it in millions."""
    total_params = sum(p.numel() for p in model.parameters())
    if verbose:
        print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
    return total_params


# Checkpoint file extensions accepted by load_weights.
VALID_FORMATS = [".pt", ".pth", ".ckpt", ".safetensors", ".bin"]

def load_weights(path, weights_only=True):
    """Load a checkpoint (torch or safetensors) as a flat state dict on CPU.

    Unwraps a top-level 'state_dict' key when present.
    """
    ext = osp.splitext(path)[-1]
    assert ext in VALID_FORMATS, f"Invalid checkpoint format {ext}"
    if ext == ".safetensors":
        sd = {}
        safe_sd = safe_open(path, framework="pt", device="cpu")
        for key in safe_sd.keys():
            sd[key] = safe_sd.get_tensor(key)
    else:
        sd = torch.load(path, map_location="cpu", weights_only=weights_only)
        if "state_dict" in sd.keys():
            sd = sd["state_dict"]
    return sd


def delete_states(sd, delete_keys: list[str] = (), skip_keys: list[str] = ()):
    """Delete entries of ``sd`` whose key matches any ``delete_keys`` regex,
    unless the key also matches one of the ``skip_keys`` regexes.

    Patterns are anchored at the start of the key (``re.match``). The dict is
    modified in place and returned.

    Fix: the previous implementation deleted a matching key once per skip
    pattern, raising KeyError whenever two or more skip patterns were given.
    """
    for k in list(sd.keys()):
        if any(re.match(dk, k) is not None for dk in delete_keys):
            if not any(re.match(sk, k) is not None for sk in skip_keys):
                del sd[k]
    return sd


def autocast(f, enabled=True):
    """Wrap ``f`` so it runs under CUDA AMP autocast (when ``enabled``)."""
    @wraps(f)
    def do_autocast(*args, **kwargs):
        # NOTE(review): torch.cuda.amp.autocast is deprecated in favor of
        # torch.amp.autocast("cuda"); kept for compatibility with older torch.
        with torch.cuda.amp.autocast(
                enabled=enabled,
                dtype=torch.get_autocast_gpu_dtype(),
                cache_enabled=torch.is_autocast_cache_enabled(),
        ):
            return f(*args, **kwargs)

    return do_autocast


def checkpoint_wrapper(func):
    """Method decorator: run through torch activation checkpointing unless the
    instance sets ``self.checkpoint`` to a falsy value."""
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if not hasattr(self, 'checkpoint') or self.checkpoint:
            def bound_func(*args, **kwargs):
                return func(self, *args, **kwargs)
            return checkpoint(bound_func, *args, use_reentrant=False, **kwargs)
        else:
            return func(self, *args, **kwargs)
    return wrapper