"""SAMPLING ONLY.""" from confs import * import torch import numpy as np from tqdm import tqdm from functools import partial from src.Face_models.encoders.model_irse import Backbone import torch.nn as nn import torchvision.transforms.functional as TF from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \ extract_into_tensor def un_norm_clip(x1): x = x1*1.0 # to avoid changing the original tensor or clone() can be used reduce=False if len(x.shape)==3: x = x.unsqueeze(0) reduce=True x[:,0,:,:] = x[:,0,:,:] * 0.26862954 + 0.48145466 x[:,1,:,:] = x[:,1,:,:] * 0.26130258 + 0.4578275 x[:,2,:,:] = x[:,2,:,:] * 0.27577711 + 0.40821073 if reduce: x = x.squeeze(0) return x class IDLoss(nn.Module): def __init__(self,path="Other_dependencies/arcface/model_ir_se50.pth",multiscale=False): super(IDLoss, self).__init__() print('Loading ResNet ArcFace') self.multiscale = multiscale self.face_pool_1 = torch.nn.AdaptiveAvgPool2d((256, 256)) self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') # self.facenet=iresnet100(pretrained=False, fp16=False) # changed by sanoojan self.facenet.load_state_dict(torch.load(path)) self.face_pool_2 = torch.nn.AdaptiveAvgPool2d((112, 112)) self.facenet.eval() self.set_requires_grad(False) def set_requires_grad(self, flag=True): for p in self.parameters(): p.requires_grad = flag def extract_feats(self, x,clip_img=True): # breakpoint() if clip_img: x = un_norm_clip(x) x = TF.normalize(x, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) x = self.face_pool_1(x) if x.shape[2]!=256 else x # (1) resize to 256 if needed x = x[:, :, 35:223, 32:220] # (2) Crop interesting region x = self.face_pool_2(x) # (3) resize to 112 to fit pre-trained model # breakpoint() x_feats = self.facenet(x, multi_scale=self.multiscale ) # x_feats = self.facenet(x) # changed by sanoojan return x_feats def forward(self, y_hat, y,clip_img=True,return_seperate=False): n_samples = y.shape[0] y_feats_ms = self.extract_feats(y,clip_img=clip_img) # Otherwise use the feature from there y_hat_feats_ms = self.extract_feats(y_hat,clip_img=clip_img) y_feats_ms = [y_f.detach() for y_f in y_feats_ms] loss_all = 0 sim_improvement_all = 0 seperate_losses=[] for y_hat_feats, y_feats in zip(y_hat_feats_ms, y_feats_ms): loss = 0 sim_improvement = 0 count = 0 for i in range(n_samples): sim_target = y_hat_feats[i].dot(y_feats[i]) sim_views = y_feats[i].dot(y_feats[i]) seperate_losses.append(1-sim_target) loss += 1 - sim_target # id loss sim_improvement += float(sim_target) - float(sim_views) count += 1 loss_all += loss / count sim_improvement_all += sim_improvement / count return loss_all, sim_improvement_all, None class DDIMSampler(object): def __init__(self, model, schedule="linear", **kwargs): super().__init__() self.model = model self.ddpm_num_timesteps = model.num_timesteps self.schedule = schedule # self.ID_LOSS=IDLoss() def register_buffer(self, name, attr): if type(attr) == torch.Tensor: if attr.device != torch.device("cuda"): attr = attr.to(torch.device("cuda")) setattr(self, name, attr) def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) alphas_cumprod = self.model.alphas_cumprod assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) self.register_buffer('betas', to_torch(self.model.betas)) self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) # calculations for diffusion q(x_t | x_{t-1}) and others self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) # ddim sampling parameters ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), ddim_timesteps=self.ddim_timesteps, eta=ddim_eta,verbose=verbose) self.register_buffer('ddim_sigmas', ddim_sigmas) self.register_buffer('ddim_alphas', ddim_alphas) self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) @torch.no_grad() def sample(self, S, batch_size, shape, conditioning=None, callback=None, normals_sequence=None, img_callback=None, quantize_x0=False, eta=0., mask=None, x0=None, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, verbose=True, x_T=None, log_every_t=100, unconditional_guidance_scale=1., unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... **kwargs ): if conditioning is not None: if isinstance(conditioning, dict): cbs = conditioning[list(conditioning.keys())[0]].shape[0] if cbs != batch_size: print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") else: if conditioning.shape[0] != batch_size: print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) # sampling C, H, W = shape size = (batch_size, C, H, W) # print(f'Data shape for DDIM sampling is {size}, eta {eta}') samples, intermediates = self.ddim_sampling(conditioning, size, callback=callback, img_callback=img_callback, quantize_denoised=quantize_x0, mask=mask, x0=x0, ddim_use_original_steps=False, noise_dropout=noise_dropout, temperature=temperature, score_corrector=score_corrector, corrector_kwargs=corrector_kwargs, x_T=x_T, log_every_t=log_every_t, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, **kwargs ) return samples, intermediates @torch.no_grad() def ddim_sampling(self, cond, shape, x_T=None, ddim_use_original_steps=False, callback=None, timesteps=None, quantize_denoised=False, mask=None, x0=None, img_callback=None, log_every_t=100, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, unconditional_guidance_scale=1., unconditional_conditioning=None, z_ref=None, **kwargs): device = self.model.betas.device b = shape[0] if x_T is None: img = torch.randn(shape, device=device) else: img = x_T if z_ref is not None: tensor_1c = torch.zeros((z_ref.shape[0], 1, z_ref.shape[2], z_ref.shape[3]), device=z_ref.device) if REFNET.CH9: z_ref = torch.cat([z_ref, z_ref, tensor_1c], dim=1) if timesteps is None: timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps elif timesteps is not None and not ddim_use_original_steps: subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 timesteps = self.ddim_timesteps[:subset_end] intermediates = {'x_inter': [img], 'pred_x0': [img]} time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] print(f"Running DDIM Sampling with {total_steps} timesteps") iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) for i, step in enumerate(iterator): index = total_steps - i - 1 ts = torch.full((b,), step, device=device, dtype=torch.long) if mask is not None: # None assert x0 is not None img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? img = img_orig * mask + (1. - mask) * img if z_ref is not None: z_ref_noisy = self.model.q_sample(x_start=z_ref[:,:4], t=ts, ) if REFNET.CH9: z_ref[:,:4] = z_ref_noisy # img and pred_x0 both B,4,64,64; cond/unconditional_conditioning tensors are B,1,768 outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, z_ref=z_ref, quantize_denoised=quantize_denoised, temperature=temperature, noise_dropout=noise_dropout, score_corrector=score_corrector, corrector_kwargs=corrector_kwargs, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning,**kwargs) img, pred_x0 = outs if callback: callback(i) if img_callback: img_callback(pred_x0, i) if index % log_every_t == 0 or index == total_steps - 1: intermediates['x_inter'].append(img) intermediates['pred_x0'].append(pred_x0) return img, intermediates @torch.no_grad() def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, unconditional_guidance_scale=1., unconditional_conditioning=None, z_ref=None, **kwargs): """ 0. input param is: (x, c, t, [z_ref] ) 1. x=concat(x,inpaint,mask) 2. apply_model(x, t, c, [z_ref] ) ( similar to ddpm.py LatentDiffusion.p_losses() """ b, *_, device = *x.shape, x.device if 1: z_inpaint = kwargs['z_inpaint'] # B,4 z_inpaint_mask = kwargs['z_inpaint_mask'] # B,1 z9 = kwargs['z9'] # B,9or14 # x = torch.cat([x, z_inpaint, z_inpaint_mask],dim=1) # B,9,... x = torch.cat([x, z9[:,4:] ],dim=1) # B,9or14,... if unconditional_conditioning is None or unconditional_guidance_scale == 1.: e_t = self.model.apply_model(x, t, c, z_ref=z_ref,) else: # check @ sanoojan if MERGE_CFG_in_one_batch: # b,... -> 2b,... x_in = torch.cat([x] * 2) #x_in: 2,9,64,64 t_in = torch.cat([t] * 2) if z_ref is not None: z_ref_in = torch.cat([z_ref] * 2) else: z_ref_in = None c_in = torch.cat([unconditional_conditioning, c]) #c_in: 2,1,768 e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in, z_ref=z_ref_in,).chunk(2) else: # first infer unconditional then conditional (reduces peak CUDA memory) e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning, z_ref=z_ref,) e_t = self.model.apply_model(x, t, c, z_ref=z_ref,) e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) #1,4,64,64 if score_corrector is not None: assert self.model.parameterization == "eps" e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas # select parameters corresponding to the currently considered timestep a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) # current prediction for x_0 if x.shape[1]!=4: pred_x0 = (x[:,:4,:,:] - sqrt_one_minus_at * e_t) / a_t.sqrt() else: pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() if quantize_denoised: pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) # direction pointing to x_t dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t noise = sigma_t * noise_like(dir_xt.shape, device, repeat_noise) * temperature if noise_dropout > 0.: noise = torch.nn.functional.dropout(noise, p=noise_dropout) x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise return x_prev, pred_x0 def sample_train(self, S, batch_size, shape, conditioning=None, callback=None, normals_sequence=None, img_callback=None, quantize_x0=False, eta=0., mask=None, x0=None, temperature=1., noise_dropout=0., t=None, score_corrector=None, corrector_kwargs=None, verbose=True, x_T=None, log_every_t=100, unconditional_guidance_scale=1., unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... **kwargs ): if conditioning is not None: if isinstance(conditioning, dict): cbs = conditioning[list(conditioning.keys())[0]].shape[0] if cbs != batch_size: print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") else: if conditioning.shape[0] != batch_size: print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) # sampling C, H, W = shape size = (batch_size, C, H, W) # print(f'Data shape for DDIM sampling is {size}, eta {eta}') # for param in self.model.first_stage_model.parameters(): # param.requires_grad = False samples, intermediates = self.ddim_sampling_train(conditioning, size, callback=callback, img_callback=img_callback, quantize_denoised=quantize_x0, mask=mask, x0=x0, ddim_use_original_steps=False, noise_dropout=noise_dropout, temperature=temperature, score_corrector=score_corrector, corrector_kwargs=corrector_kwargs, x_T=x_T,ddim_num_steps=S, curr_t=t, log_every_t=log_every_t, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, **kwargs ) return samples, intermediates def ddim_sampling_train(self, cond, shape, x_T=None, ddim_use_original_steps=False,ddim_num_steps=None, callback=None, timesteps=None, quantize_denoised=False, mask=None, x0=None, img_callback=None, log_every_t=100,curr_t=None, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, unconditional_guidance_scale=1., unconditional_conditioning=None,**kwargs): device = self.model.betas.device b = shape[0] if x_T is None: img = torch.randn(shape, device=device) else: img = x_T kwargs['rest']=img[:,4:,:,:] img=img[:,:4,:,:] if timesteps is None: timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps elif timesteps is not None and not ddim_use_original_steps: subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 timesteps = self.ddim_timesteps[:subset_end] curr_t=curr_t.cpu().numpy() skip = (curr_t-1) // ddim_num_steps # replace all 0s with 1s skip[skip == 0] = 1 if type(skip)!=int: seq=[range(1, curr_t[n]-1, skip[n]) for n in range(len(curr_t))] min_length = min(len(sublist) for sublist in seq) min_length=min(min_length,ddim_num_steps) # Create a new list of sublists by truncating each sublist to the minimum length truncated_seq = [sublist[:min_length] for sublist in seq] seq= np.array(truncated_seq) # seq=np.flip(seq) #concatenate all sequences # seq = np.concatenate(seq) seq=torch.from_numpy(seq).to(device) seq=torch.flip(seq,dims=[1]) intermediates = {'x_inter': [img], 'pred_x0': [img]} intermediates = {'x_inter': [], 'pred_x0': []} # time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) # total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] # print(f"Running DDIM Sampling with {total_steps} timesteps") # time_range=np.array([1]) # iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) total_steps=seq.shape[1] # 4 (ddim 4 steps) for i in range(seq.shape[1]): index = total_steps - i - 1 # ts = torch.full((b,), step, device=device, dtype=torch.long) ts=seq[:,i].long() #make it toech long # ts=ts.long() if mask is not None: assert x0 is not None img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? img = img_orig * mask + (1. - mask) * img outs = self.p_sample_ddim_train(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, quantize_denoised=quantize_denoised, temperature=temperature, noise_dropout=noise_dropout, score_corrector=score_corrector, corrector_kwargs=corrector_kwargs, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning,**kwargs) img, pred_x0 = outs if callback: callback(i) if img_callback: img_callback(pred_x0, i) # if index % log_every_t == 0 or index == total_steps - 1: if i in [ total_steps - 1, ]: # if 1: # len_inter 4 (5 if orig rf) => OOM intermediates['x_inter'].append(img) intermediates['pred_x0'].append(pred_x0) return img, intermediates def p_sample_ddim_train(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, unconditional_guidance_scale=1., unconditional_conditioning=None,return_features=False,**kwargs): b, *_, device = *x.shape, x.device # if 'test_model_kwargs' in kwargs: # kwargs=kwargs['test_model_kwargs'] # x = torch.cat([x, kwargs['inpaint_image'], kwargs['inpaint_mask']],dim=1) if 'rest' in kwargs: x = torch.cat((x, kwargs['rest']), dim=1) z_ref = kwargs.pop('z_ref',None) if unconditional_conditioning is None or unconditional_guidance_scale == 1.: e_t = self.model.apply_model(x, t, c,return_features=return_features,z_ref=z_ref) else: # check @ sanoojan assert 0 x_in = torch.cat([x] * 2) #x_in: 2,9,64,64 t_in = torch.cat([t] * 2) c_in = torch.cat([unconditional_conditioning, c]) #c_in: 2,1,768 if return_features: e_t_uncond, e_t,features = self.model.apply_model(x_in, t_in, c_in,return_features=return_features).chunk(3) e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) #1,4,64,64 if score_corrector is not None: assert self.model.parameterization == "eps" e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas # select parameters corresponding to the currently considered timestep a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) # current prediction for x_0 if x.shape[1]!=4: pred_x0 = (x[:,:4,:,:] - sqrt_one_minus_at * e_t) / a_t.sqrt() else: pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() if quantize_denoised: pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) # direction pointing to x_t dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t noise = sigma_t * noise_like(dir_xt.shape, device, repeat_noise) * temperature if noise_dropout > 0.: noise = torch.nn.functional.dropout(noise, p=noise_dropout) x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise return x_prev, pred_x0 @torch.no_grad() def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, use_original_steps=False): assert 0