| """SAMPLING ONLY.""" |
| import os |
| import torch |
| from torch import nn |
| import torchvision |
| import numpy as np |
| from tqdm import tqdm |
| from functools import partial |
| from PIL import Image |
| import shutil |
|
|
| from ldm.modules.diffusionmodules.util import ( |
| make_ddim_sampling_parameters, |
| make_ddim_timesteps, |
| noise_like, |
| ) |
| import clip |
| from einops import rearrange |
| import random |
|
|
|
|
| class VGGPerceptualLoss(torch.nn.Module): |
| def __init__(self, resize=True): |
| super(VGGPerceptualLoss, self).__init__() |
| blocks = [] |
| blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval()) |
| blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval()) |
| blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval()) |
| blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval()) |
| for bl in blocks: |
| for p in bl.parameters(): |
| p.requires_grad = False |
| self.blocks = torch.nn.ModuleList(blocks) |
| self.transform = torch.nn.functional.interpolate |
| self.resize = resize |
| self.register_buffer( |
| "mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) |
| ) |
| self.register_buffer( |
| "std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) |
| ) |
|
|
| def forward(self, input, target, feature_layers=[0, 1, 2, 3], style_layers=[]): |
| input = (input - self.mean) / self.std |
| target = (target - self.mean) / self.std |
| if self.resize: |
| input = self.transform( |
| input, mode="bilinear", size=(224, 224), align_corners=False |
| ) |
| target = self.transform( |
| target, mode="bilinear", size=(224, 224), align_corners=False |
| ) |
| loss = 0.0 |
| x = input |
| y = target |
| for i, block in enumerate(self.blocks): |
| x = block(x) |
| y = block(y) |
| if i in feature_layers: |
| loss += torch.nn.functional.l1_loss(x, y) |
| if i in style_layers: |
| act_x = x.reshape(x.shape[0], x.shape[1], -1) |
| act_y = y.reshape(y.shape[0], y.shape[1], -1) |
| gram_x = act_x @ act_x.permute(0, 2, 1) |
| gram_y = act_y @ act_y.permute(0, 2, 1) |
| loss += torch.nn.functional.l1_loss(gram_x, gram_y) |
| return loss |
|
|
|
|
| class DCLIPLoss(torch.nn.Module): |
| def __init__(self): |
| super(DCLIPLoss, self).__init__() |
| self.model, self.preprocess = clip.load("ViT-B/32", device="cuda") |
| self.upsample = torch.nn.Upsample(scale_factor=7) |
| self.avg_pool = torch.nn.AvgPool2d(kernel_size=16) |
|
|
| def forward(self, image1, image2, text1, text2): |
| text1 = clip.tokenize([text1]).to("cuda") |
| text2 = clip.tokenize([text2]).to("cuda") |
| image1 = image1.unsqueeze(0).cuda() |
| image2 = image2.unsqueeze(0) |
| image1 = self.avg_pool(self.upsample(image1)) |
| image2 = self.avg_pool(self.upsample(image2)) |
| image1_feat = self.model.encode_image(image1) |
| image2_feat = self.model.encode_image(image2) |
| text1_feat = self.model.encode_text(text1) |
| text2_feat = self.model.encode_text(text2) |
| d_image_feat = image1_feat - image2_feat |
| d_text_feat = text1_feat - text2_feat |
| similarity = torch.nn.CosineSimilarity()(d_image_feat, d_text_feat) |
| return 1 - similarity |
|
|
|
|
| class PLMSSampler(object): |
| def __init__(self, model, schedule="linear", **kwargs): |
| super().__init__() |
| self.model = model |
| self.ddpm_num_timesteps = model.num_timesteps |
| self.schedule = schedule |
|
|
| def register_buffer(self, name, attr): |
| if type(attr) == torch.Tensor: |
| if attr.device != torch.device("cuda"): |
| attr = attr.to(torch.device("cuda")) |
| setattr(self, name, attr) |
|
|
| def make_schedule( |
| self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True |
| ): |
| if ddim_eta != 0: |
| raise ValueError("ddim_eta must be 0 for PLMS") |
| self.ddim_timesteps = make_ddim_timesteps( |
| ddim_discr_method=ddim_discretize, |
| num_ddim_timesteps=ddim_num_steps, |
| num_ddpm_timesteps=self.ddpm_num_timesteps, |
| verbose=verbose, |
| ) |
| alphas_cumprod = self.model.alphas_cumprod |
| assert ( |
| alphas_cumprod.shape[0] == self.ddpm_num_timesteps |
| ), "alphas have to be defined for each timestep" |
| to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) |
|
|
| self.register_buffer("betas", to_torch(self.model.betas)) |
| self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) |
| self.register_buffer( |
| "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev) |
| ) |
|
|
| |
| self.register_buffer( |
| "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu())) |
| ) |
| self.register_buffer( |
| "sqrt_one_minus_alphas_cumprod", |
| to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())), |
| ) |
| self.register_buffer( |
| "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu())) |
| ) |
| self.register_buffer( |
| "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())) |
| ) |
| self.register_buffer( |
| "sqrt_recipm1_alphas_cumprod", |
| to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)), |
| ) |
|
|
| |
| ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( |
| alphacums=alphas_cumprod.cpu(), |
| ddim_timesteps=self.ddim_timesteps, |
| eta=0.0, |
| verbose=verbose, |
| ) |
| self.register_buffer("ddim_sigmas", ddim_sigmas) |
| self.register_buffer("ddim_alphas", ddim_alphas) |
| self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) |
| self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) |
| sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( |
| (1 - self.alphas_cumprod_prev) |
| / (1 - self.alphas_cumprod) |
| * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) |
| ) |
| self.register_buffer( |
| "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps |
| ) |
|
|
| @torch.no_grad() |
| def sample(self, |
| S, |
| batch_size, |
| shape, |
| conditioning=None, |
| callback=None, |
| normals_sequence=None, |
| img_callback=None, |
| quantize_x0=False, |
| eta=0., |
| mask=None, |
| x0=None, |
| temperature=1., |
| noise_dropout=0., |
| score_corrector=None, |
| corrector_kwargs=None, |
| verbose=True, |
| x_T=None, |
| log_every_t=100, |
| unconditional_guidance_scale=1., |
| unconditional_conditioning=None, |
| |
| dynamic_threshold=None, |
| **kwargs |
| ): |
| if conditioning is not None: |
| if isinstance(conditioning, dict): |
| cbs = conditioning[list(conditioning.keys())[0]][0].shape[0] |
| if cbs != batch_size: |
| print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") |
| else: |
| if conditioning.shape[0] != batch_size: |
| print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") |
|
|
| self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) |
| |
| C, H, W = shape |
| size = (batch_size, C, H, W) |
| print(f'Data shape for PLMS sampling is {size}') |
|
|
| samples, intermediates = self.plms_sampling(conditioning, size, |
| callback=callback, |
| img_callback=img_callback, |
| quantize_denoised=quantize_x0, |
| mask=mask, x0=x0, |
| ddim_use_original_steps=False, |
| noise_dropout=noise_dropout, |
| temperature=temperature, |
| score_corrector=score_corrector, |
| corrector_kwargs=corrector_kwargs, |
| x_T=x_T, |
| log_every_t=log_every_t, |
| unconditional_guidance_scale=unconditional_guidance_scale, |
| unconditional_conditioning=unconditional_conditioning, |
| ) |
| return samples, intermediates |
| |
| @torch.no_grad() |
| def plms_sampling( |
| self, |
| cond, |
| shape, |
| x_T=None, |
| ddim_use_original_steps=False, |
| callback=None, |
| timesteps=None, |
| quantize_denoised=False, |
| mask=None, |
| x0=None, |
| img_callback=None, |
| log_every_t=100, |
| temperature=1.0, |
| noise_dropout=0.0, |
| score_corrector=None, |
| corrector_kwargs=None, |
| unconditional_guidance_scale=1.0, |
| unconditional_conditioning=None, |
| ): |
| device = self.model.betas.device |
| b = shape[0] |
| if x_T is None: |
| img = torch.randn(shape, device=device) |
| else: |
| img = x_T |
|
|
| if timesteps is None: |
| timesteps = ( |
| self.ddpm_num_timesteps |
| if ddim_use_original_steps |
| else self.ddim_timesteps |
| ) |
| elif timesteps is not None and not ddim_use_original_steps: |
| subset_end = ( |
| int( |
| min(timesteps / self.ddim_timesteps.shape[0], 1) |
| * self.ddim_timesteps.shape[0] |
| ) |
| - 1 |
| ) |
| timesteps = self.ddim_timesteps[:subset_end] |
|
|
| intermediates = {"x_inter": [img], "pred_x0": [img]} |
| time_range = ( |
| list(reversed(range(0, timesteps))) |
| if ddim_use_original_steps |
| else np.flip(timesteps) |
| ) |
| total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] |
| print(f"Running PLMS Sampling with {total_steps} timesteps") |
|
|
| iterator = tqdm(time_range, desc="PLMS Sampler", total=total_steps) |
| old_eps = [] |
|
|
| for i, step in enumerate(iterator): |
| index = total_steps - i - 1 |
| ts = torch.full((b,), step, device=device, dtype=torch.long) |
| ts_next = torch.full( |
| (b,), |
| time_range[min(i + 1, len(time_range) - 1)], |
| device=device, |
| dtype=torch.long, |
| ) |
|
|
| if mask is not None: |
| assert x0 is not None |
| |
| img_orig = self.model.q_sample( |
| x0, ts |
| ) |
| img = img_orig * mask + (1.0 - mask) * img |
|
|
| outs = self.p_sample_plms( |
| img, |
| cond, |
| ts, |
| index=index, |
| use_original_steps=ddim_use_original_steps, |
| quantize_denoised=quantize_denoised, |
| temperature=temperature, |
| noise_dropout=noise_dropout, |
| score_corrector=score_corrector, |
| corrector_kwargs=corrector_kwargs, |
| unconditional_guidance_scale=unconditional_guidance_scale, |
| unconditional_conditioning=unconditional_conditioning, |
| old_eps=old_eps, |
| t_next=ts_next, |
| ) |
| img, pred_x0, e_t = outs |
| old_eps.append(e_t) |
| if len(old_eps) >= 4: |
| old_eps.pop(0) |
| if callback: |
| callback(i) |
| if img_callback: |
| img_callback(pred_x0, i) |
|
|
| if index % 1 == 0 or index == total_steps - 1: |
| intermediates["x_inter"].append(img) |
| intermediates["pred_x0"].append(pred_x0) |
|
|
| return img, intermediates |
|
|
| @torch.no_grad() |
| def p_sample_plms( |
| self, |
| x, |
| c, |
| t, |
| index, |
| repeat_noise=False, |
| use_original_steps=False, |
| quantize_denoised=False, |
| temperature=1.0, |
| noise_dropout=0.0, |
| score_corrector=None, |
| corrector_kwargs=None, |
| unconditional_guidance_scale=1.0, |
| unconditional_conditioning=None, |
| old_eps=None, |
| t_next=None, |
| ): |
| b, *_, device = *x.shape, x.device |
|
|
| def get_model_output(x, t): |
| if ( |
| unconditional_conditioning is None |
| or unconditional_guidance_scale == 1.0 |
| ): |
| e_t = self.model.apply_model(x, t, c) |
| else: |
| x_in = torch.cat([x] * 2) |
| t_in = torch.cat([t] * 2) |
| if isinstance(c, dict): |
| c_in = {key: [torch.cat([unconditional_conditioning[key][0], c[key][0]])] for key in c} |
| else: |
| c_in = torch.cat([unconditional_conditioning, c]) |
| e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) |
| e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) |
|
|
| if score_corrector is not None: |
| assert self.model.parameterization == "eps" |
| e_t = score_corrector.modify_score( |
| self.model, e_t, x, t, c, **corrector_kwargs |
| ) |
|
|
| return e_t |
|
|
| alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas |
| alphas_prev = ( |
| self.model.alphas_cumprod_prev |
| if use_original_steps |
| else self.ddim_alphas_prev |
| ) |
| sqrt_one_minus_alphas = ( |
| self.model.sqrt_one_minus_alphas_cumprod |
| if use_original_steps |
| else self.ddim_sqrt_one_minus_alphas |
| ) |
| sigmas = ( |
| self.model.ddim_sigmas_for_original_num_steps |
| if use_original_steps |
| else self.ddim_sigmas |
| ) |
|
|
| def get_x_prev_and_pred_x0(e_t, index): |
| |
| a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) |
| a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) |
| sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) |
| sqrt_one_minus_at = torch.full( |
| (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device |
| ) |
|
|
| |
| pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() |
| if quantize_denoised: |
| pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) |
| |
| dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t |
| noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature |
| if noise_dropout > 0.0: |
| noise = torch.nn.functional.dropout(noise, p=noise_dropout) |
| x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise |
| return x_prev, pred_x0 |
|
|
| e_t = get_model_output(x, t) |
| if len(old_eps) == 0: |
| |
| x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) |
| e_t_next = get_model_output(x_prev, t_next) |
| e_t_prime = (e_t + e_t_next) / 2 |
| elif len(old_eps) == 1: |
| |
| e_t_prime = (3 * e_t - old_eps[-1]) / 2 |
| elif len(old_eps) == 2: |
| |
| e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 |
| elif len(old_eps) >= 3: |
| |
| e_t_prime = ( |
| 55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3] |
| ) / 24 |
|
|
| x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) |
|
|
| return x_prev, pred_x0, e_t |
|
|
| |
|
|
| |
|
|
| @torch.no_grad() |
| def sample_encode_save_noise( |
| self, |
| S, |
| batch_size, |
| shape, |
| conditioning=None, |
| callback=None, |
| normals_sequence=None, |
| img_callback=None, |
| quantize_x0=False, |
| eta=0.0, |
| mask=None, |
| x0=None, |
| temperature=1.0, |
| noise_dropout=0.0, |
| score_corrector=None, |
| corrector_kwargs=None, |
| verbose=True, |
| x_T=None, |
| log_every_t=100, |
| unconditional_guidance_scale=1.0, |
| unconditional_conditioning=None, |
| input_image=None, |
| noise_save_path=None, |
| |
| **kwargs, |
| ): |
| assert conditioning is not None |
| |
|
|
| self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) |
| |
| C, H, W = shape |
| size = (batch_size, C, H, W) |
| if verbose: |
| print(f"Data shape for PLMS sampling is {size}") |
|
|
| samples, intermediates, x0_loop = self.plms_sampling_enc_save_noise( |
| conditioning, |
| size, |
| callback=callback, |
| img_callback=img_callback, |
| quantize_denoised=quantize_x0, |
| mask=mask, |
| x0=x0, |
| ddim_use_original_steps=False, |
| noise_dropout=noise_dropout, |
| temperature=temperature, |
| score_corrector=score_corrector, |
| corrector_kwargs=corrector_kwargs, |
| x_T=x_T, |
| log_every_t=log_every_t, |
| unconditional_guidance_scale=unconditional_guidance_scale, |
| unconditional_conditioning=unconditional_conditioning, |
| input_image=input_image, |
| noise_save_path=noise_save_path, |
| verbose=verbose |
| ) |
| return samples, intermediates, x0_loop |
|
|
| @torch.no_grad() |
| def plms_sampling_enc_save_noise( |
| self, |
| cond, |
| shape, |
| x_T=None, |
| ddim_use_original_steps=False, |
| callback=None, |
| timesteps=None, |
| quantize_denoised=False, |
| mask=None, |
| x0=None, |
| img_callback=None, |
| log_every_t=100, |
| temperature=1.0, |
| noise_dropout=0.0, |
| score_corrector=None, |
| corrector_kwargs=None, |
| unconditional_guidance_scale=1.0, |
| unconditional_conditioning=None, |
| input_image=None, |
| noise_save_path=None, |
| verbose=True, |
| ): |
| device = self.model.betas.device |
|
|
| b = shape[0] |
| if x_T is None: |
| img = torch.randn(shape, device=device) |
| else: |
| img = x_T |
|
|
| if timesteps is None: |
| timesteps = ( |
| self.ddpm_num_timesteps |
| if ddim_use_original_steps |
| else self.ddim_timesteps |
| ) |
| elif timesteps is not None and not ddim_use_original_steps: |
| subset_end = ( |
| int( |
| min(timesteps / self.ddim_timesteps.shape[0], 1) |
| * self.ddim_timesteps.shape[0] |
| ) |
| - 1 |
| ) |
| timesteps = self.ddim_timesteps[:subset_end] |
|
|
| intermediates = {"x_inter": [img], "pred_x0": [img]} |
| time_range = ( |
| list(reversed(range(0, timesteps))) |
| if ddim_use_original_steps |
| else np.flip(timesteps) |
| ) |
| time_range = list(range(0, timesteps)) if ddim_use_original_steps else timesteps |
| total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] |
| if verbose: |
| print(f"Running PLMS Sampling with {total_steps} timesteps") |
| iterator = tqdm(time_range[:-1], desc='PLMS Sampler', total=total_steps) |
| else: |
| iterator = time_range[:-1] |
| old_eps = [] |
| noise_images = [] |
| for each_time in time_range: |
| noised_image = self.model.q_sample( |
| input_image, torch.tensor([each_time]).to(device) |
| ) |
| noise_images.append(noised_image) |
| |
| |
| x0_loop = input_image.clone() |
| alphas = ( |
| self.model.alphas_cumprod if ddim_use_original_steps else self.ddim_alphas |
| ) |
| alphas_prev = ( |
| self.model.alphas_cumprod_prev |
| if ddim_use_original_steps |
| else self.ddim_alphas_prev |
| ) |
| sqrt_one_minus_alphas = ( |
| self.model.sqrt_one_minus_alphas_cumprod |
| if ddim_use_original_steps |
| else self.ddim_sqrt_one_minus_alphas |
| ) |
| sigmas = ( |
| self.model.ddim_sigmas_for_original_num_steps |
| if ddim_use_original_steps |
| else self.ddim_sigmas |
| ) |
|
|
| def get_model_output(x, t): |
| x_in = torch.cat([x] * 2) |
| t_in = torch.cat([t] * 2) |
| if isinstance(cond, dict): |
| c_in = {key: [torch.cat([unconditional_conditioning[key][0], cond[key][0]])] for key in cond} |
| else: |
| c_in = torch.cat([unconditional_conditioning, cond]) |
| e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) |
| e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) |
| return e_t |
|
|
| def get_x_prev_and_pred_x0(e_t, index, curr_x0): |
| |
| a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) |
| a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) |
| sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) |
| sqrt_one_minus_at = torch.full( |
| (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device |
| ) |
|
|
| |
| pred_x0 = (curr_x0 - sqrt_one_minus_at * e_t) / a_t.sqrt() |
|
|
| a_t = torch.full((b, 1, 1, 1), alphas[index + 1], device=device) |
| a_prev = torch.full((b, 1, 1, 1), alphas_prev[index + 1], device=device) |
| sigma_t = torch.full((b, 1, 1, 1), sigmas[index + 1], device=device) |
| sqrt_one_minus_at = torch.full( |
| (b, 1, 1, 1), sqrt_one_minus_alphas[index + 1], device=device |
| ) |
|
|
| dir_xt = (1.0 - a_t - sigma_t ** 2).sqrt() * e_t |
|
|
| x_prev = a_t.sqrt() * pred_x0 + dir_xt |
|
|
| return x_prev, pred_x0 |
|
|
| for i, step in enumerate(iterator): |
| index = i |
| ts = torch.full((b,), step, device=device, dtype=torch.long) |
| ts_next = torch.full( |
| (b,), |
| time_range[min(i + 1, len(time_range) - 1)], |
| device=device, |
| dtype=torch.long, |
| ) |
| e_t = get_model_output(x0_loop, ts) |
| x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index, x0_loop) |
| x0_loop = x_prev |
| |
|
|
| |
| img = x0_loop.clone() |
| time_range = ( |
| list(reversed(range(0, timesteps))) |
| if ddim_use_original_steps |
| else np.flip(timesteps) |
| ) |
| total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] |
| if verbose: |
| print(f"Running PLMS Sampling with {total_steps} timesteps") |
| iterator = tqdm(time_range, desc="PLMS Sampler", total=total_steps, miniters=total_steps+1, mininterval=600) |
| else: |
| iterator = time_range |
| old_eps = [] |
| for i, step in enumerate(iterator): |
| index = total_steps - i - 1 |
| ts = torch.full((b,), step, device=device, dtype=torch.long) |
| ts_next = torch.full( |
| (b,), |
| time_range[min(i + 1, len(time_range) - 1)], |
| device=device, |
| dtype=torch.long, |
| ) |
|
|
| if mask is not None: |
| assert x0 is not None |
| img_orig = self.model.q_sample( |
| x0, ts |
| ) |
| img = img_orig * mask + (1.0 - mask) * img |
|
|
| outs = self.p_sample_plms_dec_save_noise( |
| img, |
| cond, |
| ts, |
| index=index, |
| use_original_steps=ddim_use_original_steps, |
| quantize_denoised=quantize_denoised, |
| temperature=temperature, |
| noise_dropout=noise_dropout, |
| score_corrector=score_corrector, |
| corrector_kwargs=corrector_kwargs, |
| unconditional_guidance_scale=unconditional_guidance_scale, |
| unconditional_conditioning=unconditional_conditioning, |
| old_eps=old_eps, |
| t_next=ts_next, |
| input_image=input_image, |
| noise_save_path=noise_save_path, |
| noise_image=noise_images.pop(), |
| ) |
| img, pred_x0, e_t = outs |
|
|
| old_eps.append(e_t) |
| if len(old_eps) >= 4: |
| old_eps.pop(0) |
| if callback: |
| callback(i) |
| if img_callback: |
| img_callback(pred_x0, i) |
|
|
| if index % log_every_t == 0 or index == total_steps - 1: |
| intermediates["x_inter"].append(img) |
| intermediates["pred_x0"].append(pred_x0) |
|
|
| return img, intermediates, x0_loop |
|
|
| @torch.no_grad() |
| def p_sample_plms_dec_save_noise( |
| self, |
| x, |
| c1, |
| t, |
| index, |
| repeat_noise=False, |
| use_original_steps=False, |
| quantize_denoised=False, |
| temperature=1.0, |
| noise_dropout=0.0, |
| score_corrector=None, |
| corrector_kwargs=None, |
| unconditional_guidance_scale=1.0, |
| unconditional_conditioning=None, |
| old_eps=None, |
| t_next=None, |
| input_image=None, |
| noise_save_path=None, |
| noise_image=None, |
| ): |
| b, *_, device = *x.shape, x.device |
|
|
| def get_model_output(x, t): |
| if ( |
| unconditional_conditioning is None |
| or unconditional_guidance_scale == 1.0 |
| ): |
| e_t = self.model.apply_model(x, t, c1) |
| else: |
| x_in = torch.cat([x] * 2) |
| t_in = torch.cat([t] * 2) |
| if isinstance(c1, dict): |
| c_in = {key: [torch.cat([unconditional_conditioning[key][0], c1[key][0]])] for key in c1} |
| else: |
| c_in = torch.cat([unconditional_conditioning, c1]) |
| e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) |
| e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) |
| return e_t |
|
|
| alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas |
| alphas_prev = ( |
| self.model.alphas_cumprod_prev |
| if use_original_steps |
| else self.ddim_alphas_prev |
| ) |
| sqrt_one_minus_alphas = ( |
| self.model.sqrt_one_minus_alphas_cumprod |
| if use_original_steps |
| else self.ddim_sqrt_one_minus_alphas |
| ) |
| sigmas = ( |
| self.model.ddim_sigmas_for_original_num_steps |
| if use_original_steps |
| else self.ddim_sigmas |
| ) |
|
|
| def get_x_prev_and_pred_x0(e_t, index): |
| |
| a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) |
| a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) |
| sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) |
| sqrt_one_minus_at = torch.full( |
| (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device |
| ) |
|
|
| |
| pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() |
| if quantize_denoised: |
| pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) |
| |
| dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t |
| time_curr = index * 20 + 1 |
| |
| img_prev = noise_image |
| noise = img_prev - a_prev.sqrt() * pred_x0 - dir_xt |
| |
|
|
| x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise |
| return x_prev, pred_x0 |
|
|
| e_t = get_model_output(x, t) |
| if len(old_eps) == 0: |
| |
| x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) |
| e_t_next = get_model_output(x_prev, t_next) |
| e_t_prime = (e_t + e_t_next) / 2 |
| elif len(old_eps) == 1: |
| |
| e_t_prime = (3 * e_t - old_eps[-1]) / 2 |
| elif len(old_eps) == 2: |
| |
| e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 |
| elif len(old_eps) >= 3: |
| |
| e_t_prime = ( |
| 55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3] |
| ) / 24 |
|
|
| x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) |
|
|
| return x_prev, pred_x0, e_t |
|
|
| |
|
|
| def p_sample_plms_sampling( |
| self, |
| x, |
| c1, |
| c2, |
| t, |
| index, |
| repeat_noise=False, |
| use_original_steps=False, |
| quantize_denoised=False, |
| temperature=1.0, |
| noise_dropout=0.0, |
| score_corrector=None, |
| corrector_kwargs=None, |
| unconditional_guidance_scale=1.0, |
| unconditional_conditioning=None, |
| old_eps=None, |
| t_next=None, |
| input_image=None, |
| optimizing_weight=None, |
| noise_save_path=None, |
| ): |
| b, *_, device = *x.shape, x.device |
|
|
| def optimize_model_output(x, t): |
| |
| |
| condition = optimizing_weight * c1 + (1 - optimizing_weight) * c2 |
| if ( |
| unconditional_conditioning is None |
| or unconditional_guidance_scale == 1.0 |
| ): |
| e_t = self.model.apply_model(x, t, condition) |
| else: |
| x_in = torch.cat([x] * 2) |
| t_in = torch.cat([t] * 2) |
| if isinstance(condition, dict): |
| c_in = {key: [torch.cat([unconditional_conditioning[key][0], condition[key][0]])] for key in condition} |
| else: |
| c_in = torch.cat([unconditional_conditioning, condition]) |
| e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) |
| e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) |
| return e_t |
|
|
| alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas |
| alphas_prev = ( |
| self.model.alphas_cumprod_prev |
| if use_original_steps |
| else self.ddim_alphas_prev |
| ) |
| sqrt_one_minus_alphas = ( |
| self.model.sqrt_one_minus_alphas_cumprod |
| if use_original_steps |
| else self.ddim_sqrt_one_minus_alphas |
| ) |
| sigmas = ( |
| self.model.ddim_sigmas_for_original_num_steps |
| if use_original_steps |
| else self.ddim_sigmas |
| ) |
|
|
| def get_x_prev_and_pred_x0(e_t, index): |
| |
| a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) |
| a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) |
| sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) |
| sqrt_one_minus_at = torch.full( |
| (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device |
| ) |
|
|
| |
| pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() |
| if quantize_denoised: |
| pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) |
| |
| dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t |
| time_curr = index * 20 + 1 |
| if noise_save_path and index > 16: |
| noise = torch.load(noise_save_path + "_time%d.pt" % (time_curr))[:1] |
| else: |
| noise = torch.zeros_like(dir_xt) |
| x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise |
| return x_prev, pred_x0 |
|
|
| e_t = optimize_model_output(x, t) |
| if len(old_eps) == 0: |
| |
| x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) |
| |
| e_t_next = optimize_model_output(x_prev, t_next) |
| e_t_prime = (e_t + e_t_next) / 2 |
| elif len(old_eps) == 1: |
| |
| e_t_prime = (3 * e_t - old_eps[-1]) / 2 |
| elif len(old_eps) == 2: |
| |
| e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 |
| elif len(old_eps) >= 3: |
| |
| e_t_prime = ( |
| 55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3] |
| ) / 24 |
|
|
| x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) |
|
|
| return x_prev, pred_x0, e_t |
|
|
| |
|
|
| def sample_optimize_intrinsic_edit( |
| self, |
| S, |
| batch_size, |
| shape, |
| conditioning1=None, |
| conditioning2=None, |
| callback=None, |
| normals_sequence=None, |
| img_callback=None, |
| quantize_x0=False, |
| eta=0.0, |
| mask=None, |
| x0=None, |
| temperature=1.0, |
| noise_dropout=0.0, |
| score_corrector=None, |
| corrector_kwargs=None, |
| verbose=True, |
| x_T=None, |
| log_every_t=100, |
| unconditional_guidance_scale=1.0, |
| unconditional_conditioning=None, |
| input_image=None, |
| noise_save_path=None, |
| lambda_t=None, |
| lambda_save_path=None, |
| image_save_path=None, |
| original_text=None, |
| new_text=None, |
| otext=None, |
| noise_saved_path=None, |
| |
| **kwargs, |
| ): |
| assert conditioning1 is not None |
| assert conditioning2 is not None |
|
|
| self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) |
| |
| C, H, W = shape |
| size = (batch_size, C, H, W) |
| print(f"Data shape for PLMS sampling is {size}") |
|
|
| self.plms_sampling_optimize_intrinsic_edit( |
| conditioning1, |
| conditioning2, |
| size, |
| callback=callback, |
| img_callback=img_callback, |
| quantize_denoised=quantize_x0, |
| mask=mask, |
| x0=x0, |
| ddim_use_original_steps=False, |
| noise_dropout=noise_dropout, |
| temperature=temperature, |
| score_corrector=score_corrector, |
| corrector_kwargs=corrector_kwargs, |
| x_T=x_T, |
| log_every_t=log_every_t, |
| unconditional_guidance_scale=unconditional_guidance_scale, |
| unconditional_conditioning=unconditional_conditioning, |
| input_image=input_image, |
| noise_save_path=noise_save_path, |
| lambda_t=lambda_t, |
| lambda_save_path=lambda_save_path, |
| image_save_path=image_save_path, |
| original_text=original_text, |
| new_text=new_text, |
| otext=otext, |
| noise_saved_path=noise_saved_path, |
| ) |
| return None |
|
|
| def plms_sampling_optimize_intrinsic_edit( |
| self, |
| cond1, |
| cond2, |
| shape, |
| x_T=None, |
| ddim_use_original_steps=False, |
| callback=None, |
| timesteps=None, |
| quantize_denoised=False, |
| mask=None, |
| x0=None, |
| img_callback=None, |
| log_every_t=100, |
| temperature=1.0, |
| noise_dropout=0.0, |
| score_corrector=None, |
| corrector_kwargs=None, |
| unconditional_guidance_scale=1.0, |
| unconditional_conditioning=None, |
| input_image=None, |
| noise_save_path=None, |
| lambda_t=None, |
| lambda_save_path=None, |
| image_save_path=None, |
| original_text=None, |
| new_text=None, |
| otext=None, |
| noise_saved_path=None, |
| ): |
| |
| device = self.model.betas.device |
|
|
| b = shape[0] |
| if x_T is None: |
| img = torch.randn(shape, device=device) |
| else: |
| img = x_T |
| img_clone = img.clone() |
|
|
| if timesteps is None: |
| timesteps = ( |
| self.ddpm_num_timesteps |
| if ddim_use_original_steps |
| else self.ddim_timesteps |
| ) |
| elif timesteps is not None and not ddim_use_original_steps: |
| subset_end = ( |
| int( |
| min(timesteps / self.ddim_timesteps.shape[0], 1) |
| * self.ddim_timesteps.shape[0] |
| ) |
| - 1 |
| ) |
| timesteps = self.ddim_timesteps[:subset_end] |
|
|
| intermediates = {"x_inter": [img], "pred_x0": [img]} |
| time_range = ( |
| list(reversed(range(0, timesteps))) |
| if ddim_use_original_steps |
| else np.flip(timesteps) |
| ) |
|
|
| weighting_parameter = lambda_t |
| weighting_parameter.requires_grad = True |
| from torch import optim |
|
|
| optimizer = optim.Adam([weighting_parameter], lr=0.05) |
|
|
| print("Original image") |
| with torch.no_grad(): |
| img = img_clone.clone() |
| total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] |
| iterator = time_range |
| old_eps = [] |
|
|
| for i, step in enumerate(iterator): |
| index = total_steps - i - 1 |
| ts = torch.full((b,), step, device=device, dtype=torch.long) |
| ts_next = torch.full( |
| (b,), |
| time_range[min(i + 1, len(time_range) - 1)], |
| device=device, |
| dtype=torch.long, |
| ) |
|
|
| outs = self.p_sample_plms_sampling( |
| img, |
| cond1, |
| cond2, |
| ts, |
| index=index, |
| use_original_steps=ddim_use_original_steps, |
| quantize_denoised=quantize_denoised, |
| temperature=temperature, |
| noise_dropout=noise_dropout, |
| score_corrector=score_corrector, |
| corrector_kwargs=corrector_kwargs, |
| unconditional_guidance_scale=unconditional_guidance_scale, |
| unconditional_conditioning=unconditional_conditioning, |
| old_eps=old_eps, |
| t_next=ts_next, |
| input_image=input_image, |
| optimizing_weight=torch.ones(50)[i], |
| noise_save_path=noise_saved_path, |
| ) |
| img, pred_x0, e_t = outs |
| old_eps.append(e_t) |
| if len(old_eps) >= 4: |
| old_eps.pop(0) |
| img_temp = self.model.decode_first_stage(img) |
| img_temp_ddim = torch.clamp((img_temp + 1.0) / 2.0, min=0.0, max=1.0) |
| img_temp_ddim = img_temp_ddim.cpu().permute(0, 2, 3, 1).permute(0, 3, 1, 2) |
| |
| with torch.no_grad(): |
| x_sample = 255.0 * rearrange( |
| img_temp_ddim[0].detach().cpu().numpy(), "c h w -> h w c" |
| ) |
| imgsave = Image.fromarray(x_sample.astype(np.uint8)) |
| imgsave.save(image_save_path + "original.png") |
| readed_image = ( |
| torchvision.io.read_image(image_save_path + "original.png").float() |
| / 255 |
| ) |
| print("Optimizing start") |
| for epoch in tqdm(range(10)): |
| img = img_clone.clone() |
| total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] |
| iterator = time_range |
| old_eps = [] |
|
|
| for i, step in enumerate(iterator): |
| index = total_steps - i - 1 |
| ts = torch.full((b,), step, device=device, dtype=torch.long) |
| ts_next = torch.full( |
| (b,), |
| time_range[min(i + 1, len(time_range) - 1)], |
| device=device, |
| dtype=torch.long, |
| ) |
|
|
| outs = self.p_sample_plms_sampling( |
| img, |
| cond1, |
| cond2, |
| ts, |
| index=index, |
| use_original_steps=ddim_use_original_steps, |
| quantize_denoised=quantize_denoised, |
| temperature=temperature, |
| noise_dropout=noise_dropout, |
| score_corrector=score_corrector, |
| corrector_kwargs=corrector_kwargs, |
| unconditional_guidance_scale=unconditional_guidance_scale, |
| unconditional_conditioning=unconditional_conditioning, |
| old_eps=old_eps, |
| t_next=ts_next, |
| input_image=input_image, |
| optimizing_weight=weighting_parameter[i], |
| noise_save_path=noise_saved_path, |
| ) |
| img, pred_x0, e_t = outs |
| old_eps.append(e_t) |
| if len(old_eps) >= 4: |
| old_eps.pop(0) |
| img_temp = self.model.decode_first_stage(img) |
| img_temp_ddim = torch.clamp((img_temp + 1.0) / 2.0, min=0.0, max=1.0) |
| img_temp_ddim = img_temp_ddim.cpu() |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
| loss1 = VGGPerceptualLoss()(img_temp_ddim[0], readed_image) |
| loss2 = DCLIPLoss()( |
| readed_image, img_temp_ddim[0].float().cuda(), otext, new_text |
| ) |
| loss = 0.05 * loss1 + loss2 |
| optimizer.zero_grad() |
| loss.backward() |
| optimizer.step() |
| |
| |
| |
| if epoch < 9: |
| del img |
| else: |
| |
| with torch.no_grad(): |
| x_sample = 255.0 * rearrange( |
| img_temp_ddim[0].detach().cpu().numpy(), "c h w -> h w c" |
| ) |
| imgsave = Image.fromarray(x_sample.astype(np.uint8)) |
| imgsave.save(image_save_path + "/final.png") |
| torch.save( |
| weighting_parameter, lambda_save_path + "/weightingParam_final.pt" |
| ) |
|
|
| torch.cuda.empty_cache() |
| |
| return None |
|
|
| |
|
|
| |
|
|
| def sample_optimize_intrinsic( |
| self, |
| S, |
| batch_size, |
| shape, |
| conditioning1=None, |
| conditioning2=None, |
| callback=None, |
| normals_sequence=None, |
| img_callback=None, |
| quantize_x0=False, |
| eta=0.0, |
| mask=None, |
| x0=None, |
| temperature=1.0, |
| noise_dropout=0.0, |
| score_corrector=None, |
| corrector_kwargs=None, |
| verbose=True, |
| x_T=None, |
| log_every_t=100, |
| unconditional_guidance_scale=1.0, |
| unconditional_conditioning=None, |
| input_image=None, |
| noise_save_path=None, |
| lambda_t=None, |
| lambda_save_path=None, |
| image_save_path=None, |
| original_text=None, |
| new_text=None, |
| otext=None, |
| |
| **kwargs, |
| ): |
| assert conditioning1 is not None |
| assert conditioning2 is not None |
|
|
| self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) |
| |
| C, H, W = shape |
| size = (batch_size, C, H, W) |
| print(f"Data shape for PLMS sampling is {size}") |
|
|
| self.plms_sampling_optimize_intrinsic( |
| conditioning1, |
| conditioning2, |
| size, |
| callback=callback, |
| img_callback=img_callback, |
| quantize_denoised=quantize_x0, |
| mask=mask, |
| x0=x0, |
| ddim_use_original_steps=False, |
| noise_dropout=noise_dropout, |
| temperature=temperature, |
| score_corrector=score_corrector, |
| corrector_kwargs=corrector_kwargs, |
| x_T=x_T, |
| log_every_t=log_every_t, |
| unconditional_guidance_scale=unconditional_guidance_scale, |
| unconditional_conditioning=unconditional_conditioning, |
| input_image=input_image, |
| noise_save_path=noise_save_path, |
| lambda_t=lambda_t, |
| lambda_save_path=lambda_save_path, |
| image_save_path=image_save_path, |
| original_text=original_text, |
| new_text=new_text, |
| otext=otext, |
| ) |
| return None |
|
|
| def plms_sampling_optimize_intrinsic( |
| self, |
| cond1, |
| cond2, |
| shape, |
| x_T=None, |
| ddim_use_original_steps=False, |
| callback=None, |
| timesteps=None, |
| quantize_denoised=False, |
| mask=None, |
| x0=None, |
| img_callback=None, |
| log_every_t=100, |
| temperature=1.0, |
| noise_dropout=0.0, |
| score_corrector=None, |
| corrector_kwargs=None, |
| unconditional_guidance_scale=1.0, |
| unconditional_conditioning=None, |
| input_image=None, |
| noise_save_path=None, |
| lambda_t=None, |
| lambda_save_path=None, |
| image_save_path=None, |
| original_text=None, |
| new_text=None, |
| otext=None, |
| ): |
| device = self.model.betas.device |
|
|
| b = shape[0] |
| if x_T is None: |
| img = torch.randn(shape, device=device) |
| else: |
| img = x_T |
| img_clone = img.clone() |
|
|
| if timesteps is None: |
| timesteps = ( |
| self.ddpm_num_timesteps |
| if ddim_use_original_steps |
| else self.ddim_timesteps |
| ) |
| elif timesteps is not None and not ddim_use_original_steps: |
| subset_end = ( |
| int( |
| min(timesteps / self.ddim_timesteps.shape[0], 1) |
| * self.ddim_timesteps.shape[0] |
| ) |
| - 1 |
| ) |
| timesteps = self.ddim_timesteps[:subset_end] |
|
|
| time_range = ( |
| list(reversed(range(0, timesteps))) |
| if ddim_use_original_steps |
| else np.flip(timesteps) |
| ) |
| weighting_parameter = lambda_t |
| weighting_parameter.requires_grad = True |
| from torch import optim |
|
|
| optimizer = optim.Adam([weighting_parameter], lr=0.05) |
|
|
| print("Original image") |
| with torch.no_grad(): |
| img = img_clone.clone() |
| total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] |
| iterator = time_range |
| old_eps = [] |
|
|
| for i, step in enumerate(iterator): |
| index = total_steps - i - 1 |
| ts = torch.full((b,), step, device=device, dtype=torch.long) |
| ts_next = torch.full( |
| (b,), |
| time_range[min(i + 1, len(time_range) - 1)], |
| device=device, |
| dtype=torch.long, |
| ) |
|
|
| outs = self.p_sample_plms_sampling( |
| img, |
| cond1, |
| cond2, |
| ts, |
| index=index, |
| use_original_steps=ddim_use_original_steps, |
| quantize_denoised=quantize_denoised, |
| temperature=temperature, |
| noise_dropout=noise_dropout, |
| score_corrector=score_corrector, |
| corrector_kwargs=corrector_kwargs, |
| unconditional_guidance_scale=unconditional_guidance_scale, |
| unconditional_conditioning=unconditional_conditioning, |
| old_eps=old_eps, |
| t_next=ts_next, |
| input_image=input_image, |
| optimizing_weight=torch.ones(50)[i], |
| noise_save_path=noise_save_path, |
| ) |
| img, pred_x0, e_t = outs |
| old_eps.append(e_t) |
| if len(old_eps) >= 4: |
| old_eps.pop(0) |
| img_temp = self.model.decode_first_stage(img) |
| del img |
| img_temp_ddim = torch.clamp((img_temp + 1.0) / 2.0, min=0.0, max=1.0) |
| img_temp_ddim = img_temp_ddim.cpu().permute(0, 2, 3, 1).permute(0, 3, 1, 2) |
| |
| with torch.no_grad(): |
| x_sample = 255.0 * rearrange( |
| img_temp_ddim[0].detach().cpu().numpy(), "c h w -> h w c" |
| ) |
| imgsave = Image.fromarray(x_sample.astype(np.uint8)) |
| imgsave.save(image_save_path + "original.png") |
|
|
| readed_image = ( |
| torchvision.io.read_image(image_save_path + "original.png").float() |
| / 255 |
| ) |
|
|
| print("Optimizing start") |
| for epoch in tqdm(range(10)): |
| img = img_clone.clone() |
| total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] |
| iterator = time_range |
| old_eps = [] |
|
|
| for i, step in enumerate(iterator): |
| index = total_steps - i - 1 |
| ts = torch.full((b,), step, device=device, dtype=torch.long) |
| ts_next = torch.full( |
| (b,), |
| time_range[min(i + 1, len(time_range) - 1)], |
| device=device, |
| dtype=torch.long, |
| ) |
|
|
| outs = self.p_sample_plms_sampling( |
| img, |
| cond1, |
| cond2, |
| ts, |
| index=index, |
| use_original_steps=ddim_use_original_steps, |
| quantize_denoised=quantize_denoised, |
| temperature=temperature, |
| noise_dropout=noise_dropout, |
| score_corrector=score_corrector, |
| corrector_kwargs=corrector_kwargs, |
| unconditional_guidance_scale=unconditional_guidance_scale, |
| unconditional_conditioning=unconditional_conditioning, |
| old_eps=old_eps, |
| t_next=ts_next, |
| input_image=input_image, |
| optimizing_weight=weighting_parameter[i], |
| noise_save_path=noise_save_path, |
| ) |
| img, _, e_t = outs |
| old_eps.append(e_t) |
| if len(old_eps) >= 4: |
| old_eps.pop(0) |
| img_temp = self.model.decode_first_stage(img) |
| del img |
| img_temp_ddim = torch.clamp((img_temp + 1.0) / 2.0, min=0.0, max=1.0) |
| img_temp_ddim = img_temp_ddim.cpu() |
|
|
| |
| |
| |
| |
| |
|
|
| loss1 = VGGPerceptualLoss()(img_temp_ddim[0], readed_image) |
| loss2 = DCLIPLoss()( |
| readed_image, img_temp_ddim[0].float().cuda(), otext, new_text |
| ) |
| loss = ( |
| 0.05 * loss1 + loss2 |
| ) |
| optimizer.zero_grad() |
| loss.backward() |
| optimizer.step() |
| |
| with torch.no_grad(): |
| if epoch == 9: |
| |
| x_sample = 255.0 * rearrange( |
| img_temp_ddim[0].detach().cpu().numpy(), "c h w -> h w c" |
| ) |
| imgsave = Image.fromarray(x_sample.astype(np.uint8)) |
| imgsave.save(image_save_path + "/final.png") |
| torch.save( |
| weighting_parameter, |
| lambda_save_path + "/weightingParam_final.pt", |
| ) |
| torch.cuda.empty_cache() |
| return None |
|
|
|
|
| |
|
|