|
|
| from torchvision.datasets.utils import download_url |
| from ldm.util import instantiate_from_config |
| import torch |
| import os |
| |
| from google.colab import files |
| from IPython.display import Image as ipyimg |
| import ipywidgets as widgets |
| from PIL import Image |
| from numpy import asarray |
| from einops import rearrange, repeat |
| import torch, torchvision |
| from ldm.models.diffusion.ddim import DDIMSampler |
| from ldm.util import ismap |
| import time |
| from omegaconf import OmegaConf |
|
|
|
|
def download_models(mode):
    """Download the config and checkpoint for ``mode`` and return their paths.

    Only ``"superresolution"`` is handled; every other mode raises
    ``NotImplementedError``.

    Returns:
        A ``(path_conf, path_ckpt)`` tuple. Each entry is the target passed
        to ``download_url`` with ``'/?dl=1'`` appended.
    """
    if mode != "superresolution":
        raise NotImplementedError

    # (remote URL, local target) pairs, config first and checkpoint second.
    downloads = (
        ('https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1',
         'logs/diffusion/superresolution_bsr/configs/project.yaml'),
        ('https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1',
         'logs/diffusion/superresolution_bsr/checkpoints/last.ckpt'),
    )

    local_paths = []
    for url, target in downloads:
        download_url(url, target)
        # NOTE(review): the suffix presumably mirrors the file name that
        # download_url derives from the URL when no explicit filename is
        # given, with ``target`` acting as the directory -- confirm against
        # the torchvision documentation.
        local_paths.append(target + '/?dl=1')

    path_conf, path_ckpt = local_paths
    return path_conf, path_ckpt
|
|
|
|
def load_model_from_config(config, ckpt):
    """Instantiate the model described by ``config`` and load weights from ``ckpt``.

    Args:
        config: Configuration object whose ``model`` attribute is handed to
            ``instantiate_from_config``.
        ckpt: Path of the checkpoint file passed to ``torch.load``.

    Returns:
        A tuple ``({"model": model}, global_step)``. The model has had
        ``.cuda()`` and ``.eval()`` called on it; ``global_step`` is the
        value stored under that key in the checkpoint.

    Raises:
        KeyError: If the checkpoint has no ``"global_step"`` or
            ``"state_dict"`` entry.
    """
    print(f"Loading model from {ckpt}")
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    pl_sd = torch.load(ckpt, map_location="cpu")
    global_step = pl_sd["global_step"]
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)

    # strict=False tolerates key mismatches, so surface them instead of
    # silently throwing the report away.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    for label, keys in (("missing", missing), ("unexpected", unexpected)):
        if keys:
            print(f"{len(keys)} {label} keys when loading state dict, e.g. {list(keys)[:5]}")

    model.cuda()
    model.eval()
    return {"model": model}, global_step
|
|
|
|
def get_model(mode):
    """Download the assets for ``mode`` and return the loaded model bundle.

    The bundle is the first element returned by ``load_model_from_config``;
    the step count returned alongside it is discarded.
    """
    config_file, checkpoint_file = download_models(mode)
    parsed_config = OmegaConf.load(config_file)
    bundle, _step = load_model_from_config(parsed_config, checkpoint_file)
    return bundle
|
|
|
|
def get_custom_cond(mode):
    """Store a user-provided conditioning input under ``data/example_conditioning/<mode>``.

    Behaviour per mode:
        * ``"superresolution"``: asks for a file upload via ``files.upload()``
          and moves the first uploaded file to ``custom_<original name>``.
        * ``"text_conditional"``: displays a text widget and writes its value
          to ``custom_<first 20 chars of the value>.txt``.
        * ``"class_conditional"``: displays an integer slider and writes its
          value to ``custom.txt``.

    Raises:
        NotImplementedError: For any other ``mode``.
    """
    dest = "data/example_conditioning"

    if mode == "superresolution":
        uploaded_img = files.upload()
        filename = next(iter(uploaded_img))
        # Reuse the full file name. The previous ``name, filetype =
        # filename.split(".")`` raised ValueError for names with no dot or
        # with more than one (e.g. "photo.v2.png").
        os.rename(filename, f"{dest}/{mode}/custom_{filename}")

    elif mode == "text_conditional":
        w = widgets.Text(value='A cake with cream!', disabled=True)
        display(w)

        with open(f"{dest}/{mode}/custom_{w.value[:20]}.txt", 'w') as f:
            f.write(w.value)

    elif mode == "class_conditional":
        w = widgets.IntSlider(min=0, max=1000)
        display(w)
        with open(f"{dest}/{mode}/custom.txt", 'w') as f:
            # The slider value is an int; a text-mode file only accepts str.
            f.write(str(w.value))

    else:
        raise NotImplementedError(f"cond not implemented for mode{mode}")
|
|
|
|
def get_cond_options(mode):
    """Return the example-conditioning directory for ``mode`` and its sorted entries.

    Returns:
        A ``(directory, names)`` tuple where ``names`` is the alphabetically
        sorted result of ``os.listdir(directory)``.
    """
    cond_dir = os.path.join("data/example_conditioning", mode)
    entries = os.listdir(cond_dir)
    entries.sort()
    return cond_dir, entries
|
|
|
|
def select_cond_path(mode):
    """Display a radio-button selector over the conditioning files for ``mode``.

    The options are the sorted entries of ``data/example_conditioning/<mode>``.

    Returns:
        The directory joined with the selector's value at the time this
        function returns.
    """
    cond_dir = os.path.join("data/example_conditioning", mode)
    choices = sorted(os.listdir(cond_dir))

    selector = widgets.RadioButtons(
        options=choices,
        description='Select conditioning:',
        disabled=False,
    )
    display(selector)

    # NOTE(review): the value is read right after the widget is displayed,
    # so this is presumably the widget's initial selection rather than a
    # later user choice -- confirm with the calling notebook.
    return os.path.join(cond_dir, selector.value)
|
|
|
|
def get_cond(mode, selected_path):
    """Build the conditioning batch for ``mode`` from the file at ``selected_path``.

    For ``"superresolution"`` the image is shown, loaded as a tensor and
    stored twice in the returned dict:
        * ``"LR_image"``: the loaded image in ``1 h w c`` layout, mapped
          through ``2 * x - 1`` and moved to the CUDA device.
        * ``"image"``: the image resized by a factor of 4 in ``1 h w c``
          layout (not rescaled and not moved to another device here).

    For every other mode an empty dict is returned.
    """
    example = dict()
    if mode != "superresolution":
        return example

    up_f = 4
    visualize_cond_img(selected_path)

    low_res = torchvision.transforms.ToTensor()(Image.open(selected_path))
    low_res = low_res.unsqueeze(0)  # add the leading batch axis: 1 c h w

    target_size = [up_f * low_res.shape[2], up_f * low_res.shape[3]]
    upsampled = torchvision.transforms.functional.resize(low_res, size=target_size, antialias=True)

    # Both tensors go channel-last.
    upsampled = rearrange(upsampled, '1 c h w -> 1 h w c')
    low_res = rearrange(low_res, '1 c h w -> 1 h w c')
    low_res = 2. * low_res - 1.

    example["LR_image"] = low_res.to(torch.device("cuda"))
    example["image"] = upsampled
    return example
|
|
|
|
def visualize_cond_img(path):
    """Display the image file at ``path`` in the notebook output."""
    preview = ipyimg(filename=path)
    display(preview)
|
|
|
|
def run(model, selected_path, task, custom_steps, resize_enabled=False, classifier_ckpt=None, global_step=None):
    """Sample from ``model`` conditioned on the input found at ``selected_path``.

    Args:
        model: Model forwarded to ``make_convolutional_sample``. Its
            ``split_input_params`` attribute is set when the conditioning
            image is at least 128 in both spatial sizes and removed otherwise.
        selected_path: Path forwarded to ``get_cond``.
        task: Mode string forwarded to ``get_cond``.
        custom_steps: Forwarded as ``custom_steps`` to the sampler.
        resize_enabled: Forwarded as ``resize_enabled`` to the sampler.
        classifier_ckpt: Accepted but not used here.
        global_step: Accepted but not used here.

    Returns:
        The value returned by ``make_convolutional_sample``.
    """
    example = get_cond(task, selected_path)

    # Fixed settings for this entry point.
    n_runs = 1
    custom_shape = None
    sampling_kwargs = dict(
        mode='ddim',
        custom_steps=custom_steps,
        eta=1.,
        swap_mode=False,
        masked=False,
        invert_mask=False,
        quantize_x0=False,
        custom_schedule=None,
        decode_interval=10,
        resize_enabled=resize_enabled,
        custom_shape=custom_shape,
        temperature=1.,
        noise_dropout=0.,
        corrector=None,
        corrector_kwargs=None,
        save_intermediate_vid=False,
        make_progrow=True,
        ddim_use_x0_pred=False,
    )

    height, width = example["image"].shape[1:3]
    if height >= 128 and width >= 128:
        patch_size = 128
        patch_stride = 64
        model.split_input_params = {
            "ks": (patch_size, patch_size),
            "stride": (patch_stride, patch_stride),
            "vqf": 4,
            "patch_distributed_vq": True,
            "tie_braker": False,
            "clip_max_weight": 0.5,
            "clip_min_weight": 0.01,
            "clip_max_tie_weight": 0.5,
            "clip_min_tie_weight": 0.01,
        }
    elif hasattr(model, "split_input_params"):
        # Drop any patch settings left on the model by an earlier call.
        delattr(model, "split_input_params")

    x_T = None
    for _ in range(n_runs):
        if custom_shape is not None:
            # One noise sample, repeated along the batch axis.
            x_T = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
            x_T = repeat(x_T, '1 c h w -> b c h w', b=custom_shape[0])

        logs = make_convolutional_sample(example, model, x_T=x_T, **sampling_kwargs)
    return logs
|
|
|
|
@torch.no_grad()
def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,
                    mask=None, x0=None, quantize_x0=False, img_callback=None,
                    temperature=1., noise_dropout=0., score_corrector=None,
                    corrector_kwargs=None, x_T=None, log_every_t=None
                    ):
    """Sample with a ``DDIMSampler`` built around ``model``.

    ``shape`` is split into its first entry, used as the batch size, and the
    remaining entries, used as the per-sample shape. ``img_callback``,
    ``noise_dropout`` and ``log_every_t`` are accepted but not forwarded.

    Returns:
        The ``(samples, intermediates)`` pair produced by ``DDIMSampler.sample``.
    """
    sampler = DDIMSampler(model)
    batch_size, sample_shape = shape[0], shape[1:]
    print(f"Sampling with eta = {eta}; steps: {steps}")

    samples, intermediates = sampler.sample(
        steps,
        batch_size=batch_size,
        shape=sample_shape,
        conditioning=cond,
        callback=callback,
        normals_sequence=normals_sequence,
        quantize_x0=quantize_x0,
        eta=eta,
        mask=mask,
        x0=x0,
        temperature=temperature,
        verbose=False,
        score_corrector=score_corrector,
        corrector_kwargs=corrector_kwargs,
        x_T=x_T,
    )
    return samples, intermediates
|
|
|
|
| @torch.no_grad() |
| def make_convolutional_sample(batch, model, mode="vanilla", custom_steps=None, eta=1.0, swap_mode=False, masked=False, |
| invert_mask=True, quantize_x0=False, custom_schedule=None, decode_interval=1000, |
| resize_enabled=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None, |
| corrector_kwargs=None, x_T=None, save_intermediate_vid=False, make_progrow=True,ddim_use_x0_pred=False): |
| log = dict() |
|
|
| z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key, |
| return_first_stage_outputs=True, |
| force_c_encode=not (hasattr(model, 'split_input_params') |
| and model.cond_stage_key == 'coordinates_bbox'), |
| return_original_cond=True) |
|
|
| log_every_t = 1 if save_intermediate_vid else None |
|
|
| if custom_shape is not None: |
| z = torch.randn(custom_shape) |
| print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}") |
|
|
| z0 = None |