Spaces:
Runtime error
Runtime error
| """SAMPLING ONLY.""" | |
| from imports import * | |
| import torch | |
| import numpy as np | |
| from tqdm import tqdm | |
| from functools import partial | |
| from src.Face_models.encoders.model_irse import Backbone | |
| import torch.nn as nn | |
| import torchvision.transforms.functional as TF | |
| from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \ | |
| extract_into_tensor | |
def un_norm_clip(x1):
    """Undo CLIP image normalisation on a copy of ``x1``.

    Channel ``c`` (for c = 0, 1, 2) is multiplied by the CLIP std and shifted by the
    CLIP mean. A batched ``(B, C, H, W)`` tensor or a single ``(C, H, W)`` image is
    accepted; a 3-D input comes back 3-D. The caller's tensor is never modified.
    """
    clip_mean = (0.48145466, 0.4578275, 0.40821073)
    clip_std = (0.26862954, 0.26130258, 0.27577711)

    # Multiplying by 1.0 produces a new tensor, so the in-place channel writes
    # below cannot touch x1.
    out = x1 * 1.0
    single_image = len(out.shape) == 3
    if single_image:
        out = out.unsqueeze(0)
    for channel, (mean, std) in enumerate(zip(clip_mean, clip_std)):
        out[:, channel, :, :] = out[:, channel, :, :] * std + mean
    return out.squeeze(0) if single_image else out
class IDLoss(nn.Module):
    """Identity loss computed with a frozen IR-SE50 ``Backbone`` (ArcFace weights).

    ``forward(y_hat, y)`` embeds both image batches and penalises
    ``1 - <f(y_hat), f(y)>`` per sample, averaged over the batch and summed over the
    list of feature tensors the backbone returns.
    """

    def __init__(self, path="Other_dependencies/arcface/model_ir_se50.pth", multiscale=False):
        super(IDLoss, self).__init__()
        print('Loading ResNet ArcFace')
        self.multiscale = multiscale
        self.face_pool_1 = torch.nn.AdaptiveAvgPool2d((256, 256))
        self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
        # Load onto CPU first. Without this, a checkpoint saved from GPU tensors fails
        # on a CPU-only host. load_state_dict copies the values into the freshly built
        # backbone's parameters, so the final placement is unchanged.
        self.facenet.load_state_dict(torch.load(path, map_location="cpu"))
        self.face_pool_2 = torch.nn.AdaptiveAvgPool2d((112, 112))
        self.facenet.eval()
        self.set_requires_grad(False)

    def set_requires_grad(self, flag=True):
        """Set ``requires_grad`` on every parameter of this module (backbone included)."""
        for p in self.parameters():
            p.requires_grad = flag

    def extract_feats(self, x, clip_img=True):
        """Return the backbone features of image batch ``x``.

        When ``clip_img`` is True, ``x`` first has its CLIP normalisation undone
        (``un_norm_clip``) and is then normalised with mean/std 0.5. The batch is then
        resized to 256x256 (unless it already is), cropped to rows 35:223 and
        cols 32:220, and pooled to the 112x112 input size the backbone was built for.
        """
        if clip_img:
            x = un_norm_clip(x)
            x = TF.normalize(x, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        # (1) resize to 256x256 if needed. Check both spatial dims: a 256-high but
        # non-256-wide input must also be resized before the fixed crop.
        if tuple(x.shape[-2:]) != (256, 256):
            x = self.face_pool_1(x)
        x = x[:, :, 35:223, 32:220]  # (2) crop the region of interest
        x = self.face_pool_2(x)  # (3) resize to 112 to fit the pre-trained model
        return self.facenet(x, multi_scale=self.multiscale)

    def forward(self, y_hat, y, clip_img=True, return_seperate=False):
        """Identity loss of ``y_hat`` against the reference batch ``y``.

        Args:
            y_hat: generated images; gradients flow through these.
            y: reference images; their features are detached.
            clip_img: forwarded to ``extract_feats``.
            return_seperate: when True, the third return value is a 1-D tensor of
                per-sample losses ``1 - sim`` (concatenated over feature scales);
                otherwise it is None.

        Returns:
            ``(loss, sim_improvement, separate)``:
            - ``loss``: the sum over scales of the mean of ``1 - <f(y_hat)_i, f(y)_i>``.
            - ``sim_improvement``: a Python float, the sum over scales of
              ``mean(<f(y_hat)_i, f(y)_i> - <f(y)_i, f(y)_i>)``.
            - ``separate``: see ``return_seperate``.

        Raises:
            ValueError: if ``y`` contains no samples.
        """
        n_samples = y.shape[0]
        if n_samples == 0:
            raise ValueError("IDLoss.forward needs at least one sample")
        y_feats_ms = [y_f.detach() for y_f in self.extract_feats(y, clip_img=clip_img)]
        y_hat_feats_ms = self.extract_feats(y_hat, clip_img=clip_img)

        loss_all = 0
        sim_improvement_all = 0
        per_sample_losses = []
        for y_hat_feats, y_feats in zip(y_hat_feats_ms, y_feats_ms):
            # Row-wise dot products for the first n_samples rows, as one batched op
            # instead of a Python loop with a host sync per sample.
            y_hat_feats = y_hat_feats[:n_samples]
            y_feats = y_feats[:n_samples]
            sim_target = (y_hat_feats * y_feats).sum(dim=-1)
            sim_views = (y_feats * y_feats).sum(dim=-1)
            sample_losses = 1 - sim_target
            per_sample_losses.append(sample_losses)
            loss_all += sample_losses.sum() / n_samples
            # Logged statistic only: detached and reduced in float64, one sync per scale.
            sim_improvement_all += float(
                (sim_target.detach().double() - sim_views.double()).sum()) / n_samples
        separate = torch.cat(per_sample_losses) if (return_seperate and per_sample_losses) else None
        return loss_all, sim_improvement_all, separate
| class DDIMSampler(object): | |
| def __init__(self, model, schedule="linear", **kwargs): | |
| super().__init__() | |
| self.model = model | |
| self.ddpm_num_timesteps = model.num_timesteps | |
| self.schedule = schedule | |
| # self.ID_LOSS=IDLoss() | |
| def register_buffer(self, name, attr): | |
| if type(attr) == torch.Tensor: | |
| if attr.device != torch.device("cuda"): | |
| attr = attr.to(torch.device("cuda")) | |
| setattr(self, name, attr) | |
    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        """Build the DDIM timestep subset and cache the schedule tensors the sampler uses.

        Sets ``self.ddim_timesteps``. Through ``register_buffer``, it also stores copies
        of the model's noise-schedule tensors and the DDIM sampling parameters as
        attributes of the sampler.

        Args:
            ddim_num_steps: number of DDIM steps, selected out of ``self.ddpm_num_timesteps``.
            ddim_discretize: timestep-selection method forwarded to ``make_ddim_timesteps``.
            ddim_eta: DDIM eta, forwarded to ``make_ddim_sampling_parameters``. It also
                scales ``ddim_sigmas_for_original_num_steps``.
            verbose: forwarded to both ldm helper functions.
        """
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        # Detached float32 copy placed on the model's device.
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        # (the numpy functions are applied to CPU copies of alphas_cumprod)
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))

        # ddim sampling parameters
        # NOTE(review): the value types returned by make_ddim_sampling_parameters are not
        # visible here; register_buffer only relocates torch.Tensor values.
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
                                                                                   ddim_timesteps=self.ddim_timesteps,
                                                                                   eta=ddim_eta,verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        # Reads the alphas_cumprod / alphas_cumprod_prev attributes registered above,
        # so this must stay after those register_buffer calls.
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                    1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
| def sample(self, | |
| S, | |
| batch_size, | |
| shape, | |
| conditioning=None, | |
| callback=None, | |
| normals_sequence=None, | |
| img_callback=None, | |
| quantize_x0=False, | |
| eta=0., | |
| mask=None, | |
| x0=None, | |
| temperature=1., | |
| noise_dropout=0., | |
| score_corrector=None, | |
| corrector_kwargs=None, | |
| verbose=True, | |
| x_T=None, | |
| log_every_t=100, | |
| unconditional_guidance_scale=1., | |
| unconditional_conditioning=None, | |
| # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... | |
| **kwargs | |
| ): | |
| if conditioning is not None: | |
| if isinstance(conditioning, dict): | |
| cbs = conditioning[list(conditioning.keys())[0]].shape[0] | |
| if cbs != batch_size: | |
| print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") | |
| else: | |
| if conditioning.shape[0] != batch_size: | |
| print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") | |
| self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) | |
| # sampling | |
| C, H, W = shape | |
| size = (batch_size, C, H, W) | |
| # print(f'Data shape for DDIM sampling is {size}, eta {eta}') | |
| samples, intermediates = self.ddim_sampling(conditioning, size, | |
| callback=callback, | |
| img_callback=img_callback, | |
| quantize_denoised=quantize_x0, | |
| mask=mask, x0=x0, | |
| ddim_use_original_steps=False, | |
| noise_dropout=noise_dropout, | |
| temperature=temperature, | |
| score_corrector=score_corrector, | |
| corrector_kwargs=corrector_kwargs, | |
| x_T=x_T, | |
| log_every_t=log_every_t, | |
| unconditional_guidance_scale=unconditional_guidance_scale, | |
| unconditional_conditioning=unconditional_conditioning, | |
| **kwargs | |
| ) | |
| return samples, intermediates | |
| def ddim_sampling(self, cond, shape, | |
| x_T=None, ddim_use_original_steps=False, | |
| callback=None, timesteps=None, quantize_denoised=False, | |
| mask=None, x0=None, img_callback=None, log_every_t=100, | |
| temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, | |
| unconditional_guidance_scale=1., unconditional_conditioning=None, | |
| z_ref=None, | |
| **kwargs): | |
| device = self.model.betas.device | |
| b = shape[0] | |
| if x_T is None: | |
| img = torch.randn(shape, device=device) | |
| else: | |
| img = x_T | |
| if z_ref is not None: | |
| tensor_1c = torch.zeros((z_ref.shape[0], 1, z_ref.shape[2], z_ref.shape[3]), device=z_ref.device) | |
| if REFNET.CH9: | |
| z_ref = torch.cat([z_ref, z_ref, tensor_1c], dim=1) | |
| if timesteps is None: | |
| timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps | |
| elif timesteps is not None and not ddim_use_original_steps: | |
| subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 | |
| timesteps = self.ddim_timesteps[:subset_end] | |
| intermediates = {'x_inter': [img], 'pred_x0': [img]} | |
| time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) | |
| total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] | |
| print(f"Running DDIM Sampling with {total_steps} timesteps") | |
| iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) | |
| for i, step in enumerate(iterator): | |
| index = total_steps - i - 1 | |
| ts = torch.full((b,), step, device=device, dtype=torch.long) | |
| if mask is not None: # None | |
| assert x0 is not None | |
| img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? | |
| img = img_orig * mask + (1. - mask) * img | |
| if z_ref is not None: | |
| z_ref_noisy = self.model.q_sample(x_start=z_ref[:,:4], t=ts, ) | |
| if REFNET.CH9: | |
| z_ref[:,:4] = z_ref_noisy | |
| # img and pred_x0 both B,4,64,64; cond/unconditional_conditioning tensors are B,1,768 | |
| outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, | |
| z_ref=z_ref, | |
| quantize_denoised=quantize_denoised, temperature=temperature, | |
| noise_dropout=noise_dropout, score_corrector=score_corrector, | |
| corrector_kwargs=corrector_kwargs, | |
| unconditional_guidance_scale=unconditional_guidance_scale, | |
| unconditional_conditioning=unconditional_conditioning,**kwargs) | |
| img, pred_x0 = outs | |
| if callback: callback(i) | |
| if img_callback: img_callback(pred_x0, i) | |
| if index % log_every_t == 0 or index == total_steps - 1: | |
| intermediates['x_inter'].append(img) | |
| intermediates['pred_x0'].append(pred_x0) | |
| return img, intermediates | |
| def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, | |
| temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, | |
| unconditional_guidance_scale=1., unconditional_conditioning=None, | |
| z_ref=None, | |
| **kwargs): | |
| """ | |
| 0. input param is: (x, c, t, [z_ref] ) | |
| 1. x=concat(x,inpaint,mask) | |
| 2. apply_model(x, t, c, [z_ref] ) | |
| ( similar to ddpm.py LatentDiffusion.p_losses() | |
| """ | |
| b, *_, device = *x.shape, x.device | |
| if 1: | |
| z_inpaint = kwargs['z_inpaint'] # B,4 | |
| z_inpaint_mask = kwargs['z_inpaint_mask'] # B,1 | |
| z9 = kwargs['z9'] # B,9or14 | |
| # x = torch.cat([x, z_inpaint, z_inpaint_mask],dim=1) # B,9,... | |
| x = torch.cat([x, z9[:,4:] ],dim=1) # B,9or14,... | |
| if unconditional_conditioning is None or unconditional_guidance_scale == 1.: | |
| e_t = self.model.apply_model(x, t, c, z_ref=z_ref,) | |
| else: # check @ sanoojan | |
| if MERGE_CFG_in_one_batch: | |
| # b,... -> 2b,... | |
| x_in = torch.cat([x] * 2) #x_in: 2,9,64,64 | |
| t_in = torch.cat([t] * 2) | |
| if z_ref is not None: | |
| z_ref_in = torch.cat([z_ref] * 2) | |
| else: | |
| z_ref_in = None | |
| batch_size = t.shape[0] | |
| double_gg_lmk = batch_size>1 and hasattr(global_, 'lmk_') and global_.lmk_ is not None | |
| if double_gg_lmk: | |
| orig_lmk_ = global_.lmk_ | |
| global_.lmk_ = torch.cat([orig_lmk_] * 2) | |
| c_in = torch.cat([unconditional_conditioning, c]) #c_in: 2,1,768 | |
| e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in, z_ref=z_ref_in,).chunk(2) | |
| if double_gg_lmk: | |
| global_.lmk_ = orig_lmk_ | |
| else: | |
| # first infer unconditional then conditional (reduces peak CUDA memory) | |
| e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning, z_ref=z_ref,) | |
| e_t = self.model.apply_model(x, t, c, z_ref=z_ref,) | |
| e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) #1,4,64,64 | |
| if score_corrector is not None: | |
| assert self.model.parameterization == "eps" | |
| e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) | |
| alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas | |
| alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev | |
| sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas | |
| sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas | |
| # select parameters corresponding to the currently considered timestep | |
| a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) | |
| a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) | |
| sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) | |
| sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) | |
| # current prediction for x_0 | |
| if x.shape[1]!=4: | |
| pred_x0 = (x[:,:4,:,:] - sqrt_one_minus_at * e_t) / a_t.sqrt() | |
| else: | |
| pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() | |
| if quantize_denoised: | |
| pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) | |
| # direction pointing to x_t | |
| dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t | |
| noise = sigma_t * noise_like(dir_xt.shape, device, repeat_noise) * temperature | |
| if noise_dropout > 0.: | |
| noise = torch.nn.functional.dropout(noise, p=noise_dropout) | |
| x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise | |
| return x_prev, pred_x0 | |
| def sample_train(self, | |
| S, | |
| batch_size, | |
| shape, | |
| conditioning=None, | |
| callback=None, | |
| normals_sequence=None, | |
| img_callback=None, | |
| quantize_x0=False, | |
| eta=0., | |
| mask=None, | |
| x0=None, | |
| temperature=1., | |
| noise_dropout=0., | |
| t=None, | |
| score_corrector=None, | |
| corrector_kwargs=None, | |
| verbose=True, | |
| x_T=None, | |
| log_every_t=100, | |
| unconditional_guidance_scale=1., | |
| unconditional_conditioning=None, | |
| # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... | |
| **kwargs | |
| ): | |
| if conditioning is not None: | |
| if isinstance(conditioning, dict): | |
| cbs = conditioning[list(conditioning.keys())[0]].shape[0] | |
| if cbs != batch_size: | |
| print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") | |
| else: | |
| if conditioning.shape[0] != batch_size: | |
| print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") | |
| self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) | |
| # sampling | |
| C, H, W = shape | |
| size = (batch_size, C, H, W) | |
| # print(f'Data shape for DDIM sampling is {size}, eta {eta}') | |
| # for param in self.model.first_stage_model.parameters(): | |
| # param.requires_grad = False | |
| samples, intermediates = self.ddim_sampling_train(conditioning, size, | |
| callback=callback, | |
| img_callback=img_callback, | |
| quantize_denoised=quantize_x0, | |
| mask=mask, x0=x0, | |
| ddim_use_original_steps=False, | |
| noise_dropout=noise_dropout, | |
| temperature=temperature, | |
| score_corrector=score_corrector, | |
| corrector_kwargs=corrector_kwargs, | |
| x_T=x_T,ddim_num_steps=S, | |
| curr_t=t, | |
| log_every_t=log_every_t, | |
| unconditional_guidance_scale=unconditional_guidance_scale, | |
| unconditional_conditioning=unconditional_conditioning, | |
| **kwargs | |
| ) | |
| return samples, intermediates | |
| def ddim_sampling_train(self, cond, shape, | |
| x_T=None, ddim_use_original_steps=False,ddim_num_steps=None, | |
| callback=None, timesteps=None, quantize_denoised=False, | |
| mask=None, x0=None, img_callback=None, log_every_t=100,curr_t=None, | |
| temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, | |
| unconditional_guidance_scale=1., unconditional_conditioning=None,**kwargs): | |
| device = self.model.betas.device | |
| b = shape[0] | |
| if x_T is None: | |
| img = torch.randn(shape, device=device) | |
| else: | |
| img = x_T | |
| kwargs['rest']=img[:,4:,:,:] | |
| img=img[:,:4,:,:] | |
| if timesteps is None: | |
| timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps | |
| elif timesteps is not None and not ddim_use_original_steps: | |
| subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 | |
| timesteps = self.ddim_timesteps[:subset_end] | |
| curr_t=curr_t.cpu().numpy() | |
| skip = (curr_t-1) // ddim_num_steps | |
| # replace all 0s with 1s | |
| skip[skip == 0] = 1 | |
| if type(skip)!=int: | |
| seq=[range(1, curr_t[n]-1, skip[n]) for n in range(len(curr_t))] | |
| min_length = min(len(sublist) for sublist in seq) | |
| min_length=min(min_length,ddim_num_steps) | |
| # Create a new list of sublists by truncating each sublist to the minimum length | |
| truncated_seq = [sublist[:min_length] for sublist in seq] | |
| seq= np.array(truncated_seq) | |
| # seq=np.flip(seq) | |
| #concatenate all sequences | |
| # seq = np.concatenate(seq) | |
| seq=torch.from_numpy(seq).to(device) | |
| seq=torch.flip(seq,dims=[1]) | |
| intermediates = {'x_inter': [img], 'pred_x0': [img]} | |
| intermediates = {'x_inter': [], 'pred_x0': []} | |
| # time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) | |
| # total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] | |
| # print(f"Running DDIM Sampling with {total_steps} timesteps") | |
| # time_range=np.array([1]) | |
| # iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) | |
| total_steps=seq.shape[1] # 4 (ddim 4 steps) | |
| for i in range(seq.shape[1]): | |
| index = total_steps - i - 1 | |
| # ts = torch.full((b,), step, device=device, dtype=torch.long) | |
| ts=seq[:,i].long() | |
| #make it toech long | |
| # ts=ts.long() | |
| if mask is not None: | |
| assert x0 is not None | |
| img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? | |
| img = img_orig * mask + (1. - mask) * img | |
| outs = self.p_sample_ddim_train(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, | |
| quantize_denoised=quantize_denoised, temperature=temperature, | |
| noise_dropout=noise_dropout, score_corrector=score_corrector, | |
| corrector_kwargs=corrector_kwargs, | |
| unconditional_guidance_scale=unconditional_guidance_scale, | |
| unconditional_conditioning=unconditional_conditioning,**kwargs) | |
| img, pred_x0 = outs | |
| if callback: callback(i) | |
| if img_callback: img_callback(pred_x0, i) | |
| # if index % log_every_t == 0 or index == total_steps - 1: | |
| if i in [ total_steps - 1, ]: | |
| # if 1: # len_inter 4 (5 if orig rf) => OOM | |
| intermediates['x_inter'].append(img) | |
| intermediates['pred_x0'].append(pred_x0) | |
| return img, intermediates | |
| def p_sample_ddim_train(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, | |
| temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, | |
| unconditional_guidance_scale=1., unconditional_conditioning=None,return_features=False,**kwargs): | |
| b, *_, device = *x.shape, x.device | |
| # if 'test_model_kwargs' in kwargs: | |
| # kwargs=kwargs['test_model_kwargs'] | |
| # x = torch.cat([x, kwargs['inpaint_image'], kwargs['inpaint_mask']],dim=1) | |
| if 'rest' in kwargs: | |
| x = torch.cat((x, kwargs['rest']), dim=1) | |
| z_ref = kwargs.pop('z_ref',None) | |
| if unconditional_conditioning is None or unconditional_guidance_scale == 1.: | |
| e_t = self.model.apply_model(x, t, c,return_features=return_features,z_ref=z_ref) | |
| else: # check @ sanoojan | |
| assert 0 | |
| x_in = torch.cat([x] * 2) #x_in: 2,9,64,64 | |
| t_in = torch.cat([t] * 2) | |
| c_in = torch.cat([unconditional_conditioning, c]) #c_in: 2,1,768 | |
| if return_features: | |
| e_t_uncond, e_t,features = self.model.apply_model(x_in, t_in, c_in,return_features=return_features).chunk(3) | |
| e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) | |
| e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) #1,4,64,64 | |
| if score_corrector is not None: | |
| assert self.model.parameterization == "eps" | |
| e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) | |
| alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas | |
| alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev | |
| sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas | |
| sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas | |
| # select parameters corresponding to the currently considered timestep | |
| a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) | |
| a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) | |
| sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) | |
| sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) | |
| # current prediction for x_0 | |
| if x.shape[1]!=4: | |
| pred_x0 = (x[:,:4,:,:] - sqrt_one_minus_at * e_t) / a_t.sqrt() | |
| else: | |
| pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() | |
| if quantize_denoised: | |
| pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) | |
| # direction pointing to x_t | |
| dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t | |
| noise = sigma_t * noise_like(dir_xt.shape, device, repeat_noise) * temperature | |
| if noise_dropout > 0.: | |
| noise = torch.nn.functional.dropout(noise, p=noise_dropout) | |
| x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise | |
| return x_prev, pred_x0 | |
| def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, | |
| use_original_steps=False): | |
| assert 0 | |