| """SAMPLING ONLY.""" |
|
|
| from confs import * |
| import torch |
| import numpy as np |
| from tqdm import tqdm |
| from functools import partial |
| from src.Face_models.encoders.model_irse import Backbone |
| import torch.nn as nn |
| import torchvision.transforms.functional as TF |
|
|
| from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \ |
| extract_into_tensor |
|
|
def un_norm_clip(x1):
    """Invert CLIP's per-channel image normalisation (``x * std + mean``).

    Accepts a (C, H, W) or (N, C, H, W) array/tensor and returns a new object
    of the same shape; the input is never modified. Only channels 0-2 are
    rescaled, any further channels are copied through unchanged.
    """
    clip_mean = (0.48145466, 0.4578275, 0.40821073)
    clip_std = (0.26862954, 0.26130258, 0.27577711)

    # Multiplying by 1.0 yields a fresh (floating-point) copy we can edit in place.
    out = x1 * 1.0
    single_image = out.ndim == 3
    if single_image:
        out = out.unsqueeze(0)

    for channel, (mu, sigma) in enumerate(zip(clip_mean, clip_std)):
        out[:, channel, :, :] = out[:, channel, :, :] * sigma + mu

    return out.squeeze(0) if single_image else out
|
|
class IDLoss(nn.Module):
    """Identity loss computed with a frozen ArcFace (IR-SE50) backbone.

    For every sample the loss is ``1 - <f(y_hat), f(y)>`` where ``f`` is the
    backbone embedding. It is averaged over the batch and summed over the
    feature tensors returned by the backbone (presumably one per scale when
    ``multiscale`` is set -- confirm against ``Backbone.forward``).
    """

    def __init__(self, path="Other_dependencies/arcface/model_ir_se50.pth", multiscale=False):
        super(IDLoss, self).__init__()
        print('Loading ResNet ArcFace')

        self.multiscale = multiscale
        # Inputs whose height is not 256 are resized to 256x256 before cropping.
        self.face_pool_1 = torch.nn.AdaptiveAvgPool2d((256, 256))
        self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
        # Load on CPU first so a checkpoint saved from a GPU also loads on CPU-only
        # hosts; load_state_dict copies the values into the module's own parameters.
        self.facenet.load_state_dict(torch.load(path, map_location="cpu"))
        # The backbone is built for 112x112 inputs.
        self.face_pool_2 = torch.nn.AdaptiveAvgPool2d((112, 112))
        self.facenet.eval()

        self.set_requires_grad(False)

    def set_requires_grad(self, flag=True):
        """Enable or disable gradients for every parameter of this module."""
        for p in self.parameters():
            p.requires_grad = flag

    def extract_feats(self, x, clip_img=True):
        """Embed a batch of images with the ArcFace backbone.

        Args:
            x: image batch indexed as (N, C, H, W).
            clip_img: if True, ``x`` is CLIP-normalised; it is mapped back to
                [0, 1] and then rescaled to [-1, 1] (mean 0.5, std 0.5).

        Returns:
            Whatever ``self.facenet`` returns for ``multi_scale=self.multiscale``.
        """
        if clip_img:
            x = un_norm_clip(x)
            x = TF.normalize(x, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        x = self.face_pool_1(x) if x.shape[2] != 256 else x
        # Fixed 188x188 crop of the 256x256 image, then resize to the backbone size.
        x = x[:, :, 35:223, 32:220]
        x = self.face_pool_2(x)
        return self.facenet(x, multi_scale=self.multiscale)

    @staticmethod
    def _scale_terms(y_hat_feats, y_feats, n_samples):
        """Loss terms for one feature tensor pair of shape (N, D).

        Returns:
            ``(mean_loss, mean_sim_improvement, per_sample_losses)``, where the
            mean loss is a 0-d tensor, the improvement a Python float, and the
            per-sample losses a 1-D tensor of ``1 - <y_hat_i, y_i>``.

        Raises:
            ValueError: if the first ``n_samples`` rows of the two inputs differ in shape.
        """
        y_hat_feats = y_hat_feats[:n_samples]
        y_feats = y_feats[:n_samples]
        if y_hat_feats.shape != y_feats.shape:
            raise ValueError(
                f"feature shape mismatch: {tuple(y_hat_feats.shape)} vs {tuple(y_feats.shape)}")
        sim_target = (y_hat_feats * y_feats).sum(dim=-1)
        sim_views = (y_feats * y_feats).sum(dim=-1)
        per_sample = 1 - sim_target
        # Diagnostic only: computed in float64 without gradients (matches the
        # former per-sample float() conversion without a GPU sync per sample).
        improvement = (sim_target.detach().double() - sim_views.detach().double()).mean().item()
        return per_sample.mean(), improvement, per_sample

    def forward(self, y_hat, y, clip_img=True, return_seperate=False):
        """Identity loss between generated images ``y_hat`` and targets ``y``.

        Args:
            y_hat: generated images; gradients flow through their features.
            y: target images; their features are detached.
            clip_img: forwarded to ``extract_feats``.
            return_seperate: if True, the third return value is the list of
                per-sample losses (0-d tensors, all feature scales in order);
                otherwise it is None.

        Returns:
            ``(loss, sim_improvement, separate_losses_or_None)``.

        Raises:
            ValueError: if ``y`` is an empty batch.
        """
        n_samples = y.shape[0]
        if n_samples == 0:
            raise ValueError("IDLoss.forward received an empty batch")

        y_feats_ms = self.extract_feats(y, clip_img=clip_img)
        y_hat_feats_ms = self.extract_feats(y_hat, clip_img=clip_img)
        y_feats_ms = [y_f.detach() for y_f in y_feats_ms]

        loss_all = 0
        sim_improvement_all = 0
        seperate_losses = []
        for y_hat_feats, y_feats in zip(y_hat_feats_ms, y_feats_ms):
            loss, sim_improvement, per_sample = self._scale_terms(y_hat_feats, y_feats, n_samples)
            loss_all += loss
            sim_improvement_all += sim_improvement
            seperate_losses.extend(per_sample)

        return loss_all, sim_improvement_all, (seperate_losses if return_seperate else None)
| |
|
|
class DDIMSampler(object):
    """DDIM sampler wrapping a latent-diffusion ``model``.

    Inference path: ``sample`` -> ``ddim_sampling`` -> ``p_sample_ddim`` (all run
    under ``torch.no_grad``). Training path: ``sample_train`` ->
    ``ddim_sampling_train`` -> ``p_sample_ddim_train``, which keep autograd
    enabled and denoise along a per-sample timestep sequence bounded by ``t``.
    Both paths share the DDIM update in ``_ddim_update``.
    """

    def __init__(self, model, schedule="linear", **kwargs):
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        # Stored for API compatibility; not read anywhere in this class.
        self.schedule = schedule

    def register_buffer(self, name, attr):
        """Set ``self.<name> = attr``, moving tensors onto the model's device.

        Uses ``self.model.device`` (the same device ``make_schedule`` targets)
        rather than a hard-coded CUDA device, so the sampler also works on
        CPU-only hosts and when the model lives on a non-default GPU.
        Non-tensor values are stored unchanged.
        """
        if isinstance(attr, torch.Tensor):
            attr = attr.to(self.model.device)
        setattr(self, name, attr)

    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        """Precompute the DDPM and DDIM schedule tables used by the samplers.

        Args:
            ddim_num_steps: number of DDIM steps selected from the DDPM schedule.
            ddim_discretize: selection method passed to ``make_ddim_timesteps``.
            ddim_eta: scales the per-step noise (sigma) of the DDIM update.
            verbose: forwarded to the ldm helper functions.
        """
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize,
                                                  num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps,
                                                  verbose=verbose)
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'

        def to_torch(x):
            return x.clone().detach().to(torch.float32).to(self.model.device)

        # Full-length DDPM tables.
        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

        # Coefficients of the closed-form forward process.
        alphas_cumprod_cpu = alphas_cumprod.cpu()
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod_cpu)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod_cpu)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod_cpu)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod_cpu)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod_cpu - 1)))

        # Tables restricted to the selected DDIM timesteps.
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
            alphacums=alphas_cumprod_cpu,
            ddim_timesteps=self.ddim_timesteps,
            eta=ddim_eta, verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))

        # Sigmas for stepping through every DDPM timestep (use_original_steps=True).
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

    @staticmethod
    def _warn_on_batch_mismatch(conditioning, batch_size):
        """Print a warning if the conditioning batch size differs from ``batch_size``.

        For a dict, only the first value's leading dimension is checked.
        """
        if conditioning is None:
            return
        if isinstance(conditioning, dict):
            cbs = conditioning[list(conditioning.keys())[0]].shape[0]
            if cbs != batch_size:
                print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
        elif conditioning.shape[0] != batch_size:
            print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               **kwargs
               ):
        """Draw ``batch_size`` samples of per-sample shape ``(C, H, W)`` in ``S`` DDIM steps.

        Extra ``kwargs`` (e.g. ``z_ref`` and the ``z9`` tensor that
        ``p_sample_ddim`` requires) are forwarded to ``ddim_sampling``.
        ``normals_sequence`` is accepted for API compatibility and ignored.

        Returns:
            ``(samples, intermediates)`` as produced by ``ddim_sampling``.
        """
        self._warn_on_batch_mismatch(conditioning, batch_size)
        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)

        C, H, W = shape
        size = (batch_size, C, H, W)
        return self.ddim_sampling(conditioning, size,
                                  callback=callback,
                                  img_callback=img_callback,
                                  quantize_denoised=quantize_x0,
                                  mask=mask, x0=x0,
                                  ddim_use_original_steps=False,
                                  noise_dropout=noise_dropout,
                                  temperature=temperature,
                                  score_corrector=score_corrector,
                                  corrector_kwargs=corrector_kwargs,
                                  x_T=x_T,
                                  log_every_t=log_every_t,
                                  unconditional_guidance_scale=unconditional_guidance_scale,
                                  unconditional_conditioning=unconditional_conditioning,
                                  **kwargs
                                  )

    @torch.no_grad()
    def ddim_sampling(self, cond, shape,
                      x_T=None, ddim_use_original_steps=False,
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None,
                      z_ref=None,
                      **kwargs):
        """Run the full DDIM reverse chain starting from ``x_T`` (or fresh noise).

        When ``z_ref`` is given and ``REFNET.CH9`` is set (``REFNET`` is presumably
        provided by ``from confs import *``), the reference is expanded to
        ``[noisy ref | clean ref | zero channel]`` along dim 1. At every step its
        first 4 channels are re-noised from the *clean* reference to the current
        timestep.

        Returns:
            ``(img, intermediates)`` where ``intermediates`` holds lists
            ``'x_inter'`` and ``'pred_x0'`` logged every ``log_every_t`` steps.

        Raises:
            ValueError: if ``mask`` is given without ``x0``.
        """
        device = self.model.betas.device
        b = shape[0]
        img = torch.randn(shape, device=device) if x_T is None else x_T

        z_ref_clean = None
        if z_ref is not None and REFNET.CH9:
            zero_channel = torch.zeros((z_ref.shape[0], 1, z_ref.shape[2], z_ref.shape[3]), device=z_ref.device)
            z_ref = torch.cat([z_ref, z_ref, zero_channel], dim=1)
            # Keep an untouched copy: noising must always start from the clean
            # reference, not from the previous step's already-noised channels.
            z_ref_clean = z_ref[:, :4].clone()

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif not ddim_use_original_steps:
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {'x_inter': [img], 'pred_x0': [img]}
        time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)

            if mask is not None:
                if x0 is None:
                    raise ValueError("mask requires x0")
                img_orig = self.model.q_sample(x0, ts)
                img = img_orig * mask + (1. - mask) * img
            if z_ref_clean is not None:
                z_ref[:, :4] = self.model.q_sample(x_start=z_ref_clean, t=ts)

            img, pred_x0 = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                              z_ref=z_ref,
                                              quantize_denoised=quantize_denoised, temperature=temperature,
                                              noise_dropout=noise_dropout, score_corrector=score_corrector,
                                              corrector_kwargs=corrector_kwargs,
                                              unconditional_guidance_scale=unconditional_guidance_scale,
                                              unconditional_conditioning=unconditional_conditioning, **kwargs)
            if callback:
                callback(i)
            if img_callback:
                img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)

        return img, intermediates

    def _ddim_update(self, x, e_t, index, use_original_steps=False, quantize_denoised=False,
                     temperature=1., noise_dropout=0., repeat_noise=False):
        """Apply one DDIM update given the noise prediction ``e_t``.

        Only the first 4 channels of ``x`` are treated as the latent being
        denoised; any additional channels are ignored here.

        Returns:
            ``(x_prev, pred_x0)``.
        """
        b, device = x.shape[0], x.device
        if use_original_steps:
            alphas = self.model.alphas_cumprod
            alphas_prev = self.model.alphas_cumprod_prev
            sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod
            # Registered on the sampler by make_schedule, not on the model.
            sigmas = self.ddim_sigmas_for_original_num_steps
        else:
            alphas = self.ddim_alphas
            alphas_prev = self.ddim_alphas_prev
            sqrt_one_minus_alphas = self.ddim_sqrt_one_minus_alphas
            sigmas = self.ddim_sigmas

        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

        pred_x0 = (x[:, :4, :, :] - sqrt_one_minus_at * e_t) / a_t.sqrt()
        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)

        dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
        noise = sigma_t * noise_like(dir_xt.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0

    @torch.no_grad()
    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None,
                      z_ref=None,
                      **kwargs):
        """One DDIM denoising step.

        ``x`` is concatenated with ``kwargs['z9'][:, 4:]`` (per the original notes,
        the inpaint latent and mask channels) before calling
        ``model.apply_model(x, t, c, z_ref=z_ref)``, analogous to
        ``LatentDiffusion.p_losses`` in ddpm.py. Classifier-free guidance is applied
        when ``unconditional_conditioning`` is given and the scale is not 1; whether
        both branches run in one batch depends on ``MERGE_CFG_in_one_batch``
        (presumably from ``confs``).

        Returns:
            ``(x_prev, pred_x0)``.
        """
        z9 = kwargs['z9']
        x = torch.cat([x, z9[:, 4:]], dim=1)

        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            e_t = self.model.apply_model(x, t, c, z_ref=z_ref)
        else:
            if MERGE_CFG_in_one_batch:
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t] * 2)
                z_ref_in = None if z_ref is None else torch.cat([z_ref] * 2)
                c_in = torch.cat([unconditional_conditioning, c])
                e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in, z_ref=z_ref_in).chunk(2)
            else:
                e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning, z_ref=z_ref)
                e_t = self.model.apply_model(x, t, c, z_ref=z_ref)
            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

        if score_corrector is not None:
            assert self.model.parameterization == "eps"
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **(corrector_kwargs or {}))

        return self._ddim_update(x, e_t, index, use_original_steps=use_original_steps,
                                 quantize_denoised=quantize_denoised, temperature=temperature,
                                 noise_dropout=noise_dropout, repeat_noise=repeat_noise)

    def sample_train(self,
                     S,
                     batch_size,
                     shape,
                     conditioning=None,
                     callback=None,
                     normals_sequence=None,
                     img_callback=None,
                     quantize_x0=False,
                     eta=0.,
                     mask=None,
                     x0=None,
                     temperature=1.,
                     noise_dropout=0.,
                     t=None,
                     score_corrector=None,
                     corrector_kwargs=None,
                     verbose=True,
                     x_T=None,
                     log_every_t=100,
                     unconditional_guidance_scale=1.,
                     unconditional_conditioning=None,
                     **kwargs
                     ):
        """Training-time sampling with autograd enabled.

        Builds an ``S``-step DDIM schedule and denoises along per-sample timestep
        sequences bounded by ``t`` (a 1-D tensor of timesteps, one per sample).
        ``normals_sequence`` is accepted for API compatibility and ignored.

        Returns:
            ``(samples, intermediates)`` as produced by ``ddim_sampling_train``.
        """
        self._warn_on_batch_mismatch(conditioning, batch_size)
        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)

        C, H, W = shape
        size = (batch_size, C, H, W)
        return self.ddim_sampling_train(conditioning, size,
                                        callback=callback,
                                        img_callback=img_callback,
                                        quantize_denoised=quantize_x0,
                                        mask=mask, x0=x0,
                                        ddim_use_original_steps=False,
                                        noise_dropout=noise_dropout,
                                        temperature=temperature,
                                        score_corrector=score_corrector,
                                        corrector_kwargs=corrector_kwargs,
                                        x_T=x_T, ddim_num_steps=S,
                                        curr_t=t,
                                        log_every_t=log_every_t,
                                        unconditional_guidance_scale=unconditional_guidance_scale,
                                        unconditional_conditioning=unconditional_conditioning,
                                        **kwargs
                                        )

    @staticmethod
    def _per_sample_timesteps(curr_t, ddim_num_steps, device):
        """Build a (batch, steps) int64 tensor of descending timesteps.

        Row ``n`` is ``range(1, curr_t[n] - 1, skip_n)`` with
        ``skip_n = (curr_t[n] - 1) // ddim_num_steps`` (0 is replaced by 1).
        All rows are truncated to the shortest row length, and to at most
        ``ddim_num_steps`` entries, so they stack. Each row is then reversed.
        """
        curr_t = curr_t.cpu().numpy()
        skip = (curr_t - 1) // ddim_num_steps
        skip[skip == 0] = 1
        rows = [range(1, t_n - 1, s_n) for t_n, s_n in zip(curr_t, skip)]
        length = min(min(len(row) for row in rows), ddim_num_steps)
        seq = np.array([row[:length] for row in rows], dtype=np.int64)
        return torch.flip(torch.from_numpy(seq).to(device), dims=[1])

    def ddim_sampling_train(self, cond, shape,
                            x_T=None, ddim_use_original_steps=False, ddim_num_steps=None,
                            callback=None, timesteps=None, quantize_denoised=False,
                            mask=None, x0=None, img_callback=None, log_every_t=100, curr_t=None,
                            temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                            unconditional_guidance_scale=1., unconditional_conditioning=None, **kwargs):
        """Denoise ``x_T`` along per-sample timestep sequences with autograd enabled.

        Channels 4 onward of ``x_T`` (or of fresh noise of ``shape``) are split off
        into ``kwargs['rest']``, which ``p_sample_ddim_train`` re-attaches before
        every model call; only the first 4 channels are denoised. ``timesteps``
        and ``log_every_t`` are accepted for signature compatibility but unused.
        Only the final ``img``/``pred_x0`` are recorded in ``intermediates``.

        Raises:
            ValueError: if ``curr_t`` is None, or ``mask`` is given without ``x0``.
        """
        if curr_t is None:
            raise ValueError("ddim_sampling_train requires curr_t")
        device = self.model.betas.device
        img = torch.randn(shape, device=device) if x_T is None else x_T

        kwargs['rest'] = img[:, 4:, :, :]
        img = img[:, :4, :, :]

        seq = self._per_sample_timesteps(curr_t, ddim_num_steps, device)
        total_steps = seq.shape[1]
        intermediates = {'x_inter': [], 'pred_x0': []}

        for i in range(total_steps):
            # NOTE(review): ``index`` selects alphas from the DDIM schedule built by
            # make_schedule, while the model is conditioned on the per-sample
            # timesteps ``ts`` taken from ``seq``; these need not coincide --
            # confirm this is intended.
            index = total_steps - i - 1
            ts = seq[:, i].long()

            if mask is not None:
                if x0 is None:
                    raise ValueError("mask requires x0")
                img_orig = self.model.q_sample(x0, ts)
                img = img_orig * mask + (1. - mask) * img
            img, pred_x0 = self.p_sample_ddim_train(img, cond, ts, index=index,
                                                    use_original_steps=ddim_use_original_steps,
                                                    quantize_denoised=quantize_denoised, temperature=temperature,
                                                    noise_dropout=noise_dropout, score_corrector=score_corrector,
                                                    corrector_kwargs=corrector_kwargs,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    unconditional_conditioning=unconditional_conditioning, **kwargs)
            if callback:
                callback(i)
            if img_callback:
                img_callback(pred_x0, i)

        if total_steps > 0:
            intermediates['x_inter'].append(img)
            intermediates['pred_x0'].append(pred_x0)

        return img, intermediates

    def p_sample_ddim_train(self, x, c, t, index, repeat_noise=False, use_original_steps=False,
                            quantize_denoised=False,
                            temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                            unconditional_guidance_scale=1., unconditional_conditioning=None,
                            return_features=False, **kwargs):
        """One DDIM step for training (autograd enabled).

        ``kwargs['rest']``, if present, is concatenated to ``x`` along dim 1, and
        ``kwargs['z_ref']`` (default None) is passed to ``apply_model``.

        Returns:
            ``(x_prev, pred_x0)``.

        Raises:
            NotImplementedError: if classifier-free guidance is requested.
        """
        if 'rest' in kwargs:
            x = torch.cat((x, kwargs['rest']), dim=1)
        z_ref = kwargs.pop('z_ref', None)

        if unconditional_conditioning is not None and unconditional_guidance_scale != 1.:
            raise NotImplementedError("classifier-free guidance is not supported in p_sample_ddim_train")
        e_t = self.model.apply_model(x, t, c, return_features=return_features, z_ref=z_ref)

        if score_corrector is not None:
            assert self.model.parameterization == "eps"
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **(corrector_kwargs or {}))

        return self._ddim_update(x, e_t, index, use_original_steps=use_original_steps,
                                 quantize_denoised=quantize_denoised, temperature=temperature,
                                 noise_dropout=noise_dropout, repeat_noise=repeat_noise)

    @torch.no_grad()
    def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
               use_original_steps=False):
        """Not implemented for this sampler."""
        raise NotImplementedError("DDIMSampler.decode is not implemented")
|
|