Spaces:
Sleeping
Sleeping
| import torch | |
| import numpy as np | |
| import os | |
| import pickle | |
| from ldm.util import default | |
| import glob | |
| import PIL | |
| import matplotlib.pyplot as plt | |
def load_file(filename):
    """Unpickle and return the object stored in ``filename``.

    Only use on trusted files: unpickling can execute arbitrary code.
    """
    with open(filename, 'rb') as handle:
        return pickle.load(handle)
def save_file(filename, x, mode="wb"):
    """Pickle ``x`` into ``filename``.

    ``mode`` is passed straight to ``open``: the default ``"wb"`` overwrites,
    ``"ab"`` appends another pickle record after any existing ones.
    """
    sink = open(filename, mode)
    with sink:
        pickle.dump(x, sink)
def normalize_np(img):
    """Linearly rescale ``img`` from its own [min, max] range to [0, 1].

    The input is never modified; the result is a new array. Floating dtypes
    are preserved, and any other dtype (int, bool, ...) is promoted to float64
    so the division is valid.

    A constant input has zero range and yields all zeros instead of NaN from
    0 / 0. An empty input raises ``ValueError``, as ``np.min`` does.
    """
    arr = np.asarray(img)
    if np.issubdtype(arr.dtype, np.floating):
        dtype = arr.dtype
    else:
        dtype = np.float64
    # Work on a copy so callers' arrays (or tensor memory behind a .numpy()
    # view) are not overwritten.
    out = arr.astype(dtype, copy=True)
    out -= out.min()
    peak = out.max()
    if peak > 0:
        out /= peak
    return out
def clear_color(x):
    """Convert an image tensor to a [0, 1]-normalized numpy array for display.

    Complex tensors are first reduced to their magnitude. Singleton dimensions
    are squeezed away. A remaining 3-D array, presumably channel-first, is
    moved to channel-last with ``transpose(1, 2, 0)``. A 2-D array (e.g. a
    single-channel image) is passed through instead of failing on the
    transpose.

    The caller's tensor is never modified.
    """
    if torch.is_complex(x):
        x = torch.abs(x)
    arr = x.detach().cpu().squeeze().numpy()
    if arr.ndim == 3:
        arr = np.transpose(arr, (1, 2, 0))
    # .numpy() on a CPU tensor shares its memory; copy before normalizing so an
    # in-place normalization cannot write back into ``x``.
    return normalize_np(np.array(arr, copy=True))
def to_img(sample):
    """Map an NCHW tensor to an NHWC float numpy array of pixel values in [0, 255].

    Each value v becomes ``v * 127.5 + 128`` and is then clipped. The +128
    (rather than +127.5) makes a later truncating ``astype(np.uint8)``
    round to nearest.
    """
    channels_last = sample.detach().cpu().numpy().transpose(0, 2, 3, 1)
    pixels = channels_last * 127.5 + 128
    return pixels.clip(0, 255)
def save_plot(dir_name, tensors, labels, file_name="loss.png"):
    """Plot each series in ``tensors`` against its step index and save the figure.

    Args:
        dir_name: existing output directory.
        tensors: sequence of series (presumably 1-D, e.g. per-iteration losses).
            Series may differ in length.
        labels: legend label per series; indexed in step with ``tensors``.
        file_name: image file name inside ``dir_name``.

    Colours cycle through red/blue/green, so more than three series are
    accepted. The figure is always closed, even if saving fails.
    """
    colours = ["r", "b", "g"]
    fig = plt.figure()
    try:
        for j, series in enumerate(tensors):
            # Integer step index 0 .. n-1 for this particular series.
            steps = np.arange(len(series))
            plt.plot(steps, series, color=colours[j % len(colours)], label=labels[j])
        plt.legend()
        plt.savefig(os.path.join(dir_name, file_name))
    finally:
        # pyplot keeps every open figure alive; close it to avoid leaking
        # memory when this is called repeatedly.
        plt.close(fig)
def save_samples(dir_name, sample, k=None, num_to_save=5, file_name=None):
    """Write the first ``num_to_save`` images of a batch to ``dir_name`` as RGB PNGs.

    Args:
        dir_name: existing output directory.
        sample: a numpy array, which is only cast to uint8 and so is presumably
            already (B, H, W, 3) pixel data, or a tensor, which is converted with
            ``to_img`` (NCHW -> NHWC pixel values).
        k: optional index. When given and ``file_name`` is None, files are named
            ``sample_{k+1}{j}.png``; note the two numbers are concatenated with
            no separator. Otherwise files are named ``{j}.png``.
        num_to_save: number of leading batch items to save. It is clamped to
            the batch size, so a small batch no longer raises ``IndexError``.
        file_name: if given, used verbatim for every saved image.
            NOTE(review): with more than one image, each overwrites the
            previous, so only the last saved one survives.
    """
    from PIL import Image  # `import PIL` alone does not guarantee PIL.Image is loaded

    if isinstance(sample, np.ndarray):
        sample_np = sample.astype(np.uint8)
    else:
        sample_np = to_img(sample).astype(np.uint8)
    for j in range(min(num_to_save, len(sample_np))):
        if file_name is not None:
            file_name_img = file_name
        elif k is not None:
            file_name_img = f'sample_{k+1}{j}.png'
        else:
            file_name_img = f'{j}.png'
        image_path = os.path.join(dir_name, file_name_img)
        Image.fromarray(sample_np[j], 'RGB').save(image_path)
def save_inpaintings(dir_name, sample, y, mask_pixel, k=None, num_to_save=5, file_name=None):
    """Composite ``y`` and ``sample`` through ``mask_pixel`` and save the results as RGB PNGs.

    Each pixel is ``y * mask_pixel + (1 - mask_pixel) * sample``: where the
    mask is 1 the pixel comes from ``y``, and where it is 0 from ``sample``.
    The composite is converted with ``to_img``, so the inputs are presumably
    NCHW tensors; the mask must broadcast against them.

    Args:
        dir_name: existing output directory.
        sample, y, mask_pixel: see above.
        k: optional index. When given and ``file_name`` is None, files are named
            ``sample_{k+1}{j}.png`` (numbers concatenated with no separator).
            Otherwise files are named ``{j}.png``.
        num_to_save: number of leading batch items to save, clamped to the
            batch size.
        file_name: if given, used verbatim for every saved image.
            NOTE(review): with more than one image, later ones overwrite
            earlier ones.
    """
    from PIL import Image  # `import PIL` alone does not guarantee PIL.Image is loaded

    recon_in = y * mask_pixel + (1 - mask_pixel) * sample
    # Convert the whole batch once instead of re-casting it for every image.
    images = to_img(recon_in).astype(np.uint8)
    for j in range(min(num_to_save, len(images))):
        if file_name is not None:
            file_name_img = file_name
        elif k is not None:
            file_name_img = f'sample_{k+1}{j}.png'
        else:
            file_name_img = f'{j}.png'
        image_path = os.path.join(dir_name, file_name_img)
        Image.fromarray(images[j], 'RGB').save(image_path)
def save_params(dir_name, mu_pos, logvar_pos, gamma, k):
    """Save detached CPU copies of ``[mu_pos, logvar_pos, gamma]`` to ``{dir_name}/{k+1}.pt``.

    The list is passed through ``params_untrain`` so that ``requires_grad`` is
    False on every stored tensor.
    """
    cpu_tensors = [t.detach().cpu() for t in (mu_pos, logvar_pos, gamma)]
    target = os.path.join(dir_name, f'{k+1}.pt')
    torch.save(params_untrain(cpu_tensors), target)
def custom_to_np(img):
    """Return ``img`` detached from autograd, moved to the CPU and made contiguous.

    Despite the name, the result is still a ``torch.Tensor``; no numpy
    conversion happens. It may share storage with ``img``.
    """
    return img.detach().cpu().contiguous()
def encoder_kl(diff, img):
    """Encode ``img`` with ``diff``'s first stage and return ``(mean, logvar)``.

    The encoder's raw output is multiplied by ``diff.scale_factor`` and split in
    half along dim 1 into mean and log-variance. The mean is then perturbed
    with standard Gaussian noise scaled by ``diff.scale_factor``.

    NOTE(review): logvar is multiplied by scale_factor together with the mean,
    and the noise uses scale_factor rather than exp(0.5 * logvar); confirm
    this is intended.
    """
    _, moments = diff.encode_first_stage(img, return_all=True)
    mean, logvar = torch.chunk(diff.scale_factor * moments, 2, dim=1)
    mean = mean + diff.scale_factor * torch.randn_like(mean)
    return mean, logvar
def encoder_vq(diff, img):
    """Encode ``img`` with ``diff``'s first stage and return a noisy scaled latent.

    The encoding is multiplied by ``diff.scale_factor``, then standard Gaussian
    noise, also scaled by ``diff.scale_factor``, is added.
    """
    latent = diff.scale_factor * diff.encode_first_stage(img)
    return latent + diff.scale_factor * torch.randn_like(latent)
def clean_directory(dir_name):
    """Delete files in a directory, or files matching a glob pattern.

    If ``dir_name`` is an existing directory, every entry directly inside it is
    considered. Otherwise ``dir_name`` is treated as a glob pattern
    (e.g. ``"out/*.png"``). Only files and symlinks are removed;
    subdirectories are left in place rather than crashing ``os.remove``.
    """
    if os.path.isdir(dir_name):
        # glob.glob on a bare directory name would match the directory itself.
        pattern = os.path.join(dir_name, "*")
    else:
        pattern = dir_name
    for path in glob.glob(pattern):
        if os.path.isfile(path) or os.path.islink(path):
            os.remove(path)
def params_train(params):
    """Turn on ``requires_grad`` for every tensor in ``params`` and return ``params`` itself."""
    for tensor in params:
        tensor.requires_grad = True
    return params
def params_untrain(params):
    """Turn off ``requires_grad`` for every tensor in ``params`` and return ``params`` itself."""
    for tensor in params:
        tensor.requires_grad = False
    return params
def time_descretization(sigma_min=0.002, sigma_max=80, rho=7, num_t_steps=18, device="cuda"):
    """Return ``num_t_steps`` rho-spaced noise levels from ``sigma_min`` up to ``sigma_max``.

    Step ``i`` of ``N = num_t_steps`` is::

        (sigma_max**(1/rho) + i/(N-1) * (sigma_min**(1/rho) - sigma_max**(1/rho)))**rho

    This is the Karras et al. (EDM) schedule. The result is flipped so that it
    is returned in ascending order, as a float64 tensor on ``device``. The
    default ``"cuda"`` matches the previous hard-coded ``.cuda()``.

    Raises:
        ValueError: if ``num_t_steps < 2``, which would otherwise yield NaN via 0/0.
    """
    if num_t_steps < 2:
        raise ValueError(f"num_t_steps must be >= 2, got {num_t_steps}")
    step_indices = torch.arange(num_t_steps, dtype=torch.float64, device=device)
    max_inv_rho = sigma_max ** (1 / rho)
    min_inv_rho = sigma_min ** (1 / rho)
    t_steps = (max_inv_rho + step_indices / (num_t_steps - 1) * (min_inv_rho - max_inv_rho)) ** rho
    return torch.flip(t_steps, dims=(0,))
def get_optimizers(means, variances, gamma_param, lr_init_gamma=0.01, *,
                   lr_means=0.1, lr_variances=0.001, step_size=10, lr_decay=0.99):
    """Build one Adam optimizer and one StepLR scheduler per parameter tensor.

    ``means``, ``variances`` and ``gamma_param`` are optimised independently so
    that each can have its own learning rate. All three share
    ``betas=(0.9, 0.99)`` and the same decay: the learning rate is multiplied
    by ``lr_decay`` every ``step_size`` scheduler steps.

    The defaults reproduce the former hard-coded values. The source comments
    mention 0.01 for the means in a right-half setup and 0.001 for the
    variances on LSUN; those are now settable without editing this function.

    Args:
        means: tensor optimised with ``lr_means``.
        variances: tensor optimised with ``lr_variances``.
        gamma_param: tensor optimised with ``lr_init_gamma``.
        lr_init_gamma: initial learning rate for ``gamma_param``.
        lr_means: initial learning rate for ``means``.
        lr_variances: initial learning rate for ``variances``.
        step_size: StepLR period, in scheduler steps.
        lr_decay: StepLR multiplicative factor (``gamma`` in PyTorch's API).

    Returns:
        ``([opt_means, opt_variances, opt_gamma],
        [sched_means, sched_variances, sched_gamma])``.
    """
    betas = (0.9, 0.99)
    groups = ((means, lr_means), (variances, lr_variances), (gamma_param, lr_init_gamma))
    optimizers = [torch.optim.Adam([param], lr=lr, betas=betas) for param, lr in groups]
    schedulers = [
        torch.optim.lr_scheduler.StepLR(opt, step_size=step_size, gamma=lr_decay)
        for opt in optimizers
    ]
    return optimizers, schedulers
def check_directory(filename_list):
    """Create every directory in ``filename_list`` that does not exist yet.

    Missing parent directories are created as well. A directory created
    concurrently between the existence check and the creation is not an
    error. Paths that already exist, whether directories or not, are left
    untouched. A single path (str, bytes or os.PathLike) is accepted as well
    as a list; previously a bare string was iterated character by character.
    """
    if isinstance(filename_list, (str, bytes, os.PathLike)):
        filename_list = [filename_list]
    for dir_path in filename_list:
        if not os.path.exists(dir_path):
            os.makedirs(dir_path, exist_ok=True)
def s_file(filename, x, mode="wb"):
    """Pickle ``x`` into ``filename`` opened with ``mode`` (same behaviour as ``save_file``)."""
    out_stream = open(filename, mode)
    with out_stream:
        pickle.dump(x, out_stream)
def r_file(filename, mode="rb"):
    """Unpickle and return the first object stored in ``filename``, opened with ``mode``.

    Only use on trusted files: unpickling can execute arbitrary code.
    """
    with open(filename, mode) as in_stream:
        obj = pickle.load(in_stream)
    return obj
def sample_from_gaussian(mu, alpha, sigma):
    """Return ``alpha * mu + sigma * eps`` with ``eps`` standard normal, shaped like ``mu``."""
    eps = torch.randn_like(mu)
    scaled_mean = alpha * mu
    return scaled_mean + sigma * eps
# Disabled helpers (make_batch, get_sigma_t_steps, data_replicate) used to sit
# here inside a bare module-level string literal. Python evaluates and discards
# such a string, so these functions were never defined at runtime. They were
# removed rather than kept as dead code; recover them from version control if
# they are needed again.
def sample_T(self, x0, eta=0.4, t_steps_hierarchy=None, out_dir="out_temp2/"):
    """Noise ``x0`` to the first timestep, then denoise it with DDIM-style steps.

    After every step, the decoded x0-prediction of batch item 0 is saved as
    ``{out_dir}/{i}.png``.

    Although defined at module level, this takes ``self``. It reads
    ``self.model.sqrt_one_minus_alphas_cumprod``, calls
    ``self.model.model(x, t)`` as the noise predictor, and calls
    ``self.model.decode_first_stage`` to decode predictions.

    Args:
        x0: clean latent batch; presumably (B, C, H, W), given the
            ``reshape(1, 1, 1, 1)`` broadcasting below. TODO confirm.
        eta: scale of the per-step stochastic term. ``eta=0`` injects no noise
            (deterministic DDIM).
        t_steps_hierarchy: integer timesteps visited in order. Step ``i`` moves
            from ``t_steps_hierarchy[i]`` to ``t_steps_hierarchy[i + 1]``
            (presumably decreasing; verify against callers).
        out_dir: directory that receives the per-step PNGs; created if missing.

    Returns:
        None. The only output is the PNG files.

    Raises:
        ValueError: if ``t_steps_hierarchy`` is None or empty.
    """
    from PIL import Image  # `import PIL` alone does not guarantee PIL.Image is loaded

    if t_steps_hierarchy is None:
        raise ValueError("sample_T requires t_steps_hierarchy")
    steps = torch.as_tensor(t_steps_hierarchy, device=x0.device).reshape(-1)
    if steps.numel() == 0:
        raise ValueError("sample_T requires a non-empty t_steps_hierarchy")
    sqrt_one_minus_ab = self.model.sqrt_one_minus_alphas_cumprod

    def noise_var(t):
        # 1 - alpha_bar(t), shaped to broadcast over a 4-D batch.
        return sqrt_one_minus_ab[t].reshape(1, 1, 1, 1) ** 2

    os.makedirs(out_dir, exist_ok=True)
    last = steps.numel() - 1
    # Only images on disk are produced, so no autograd graph is needed.
    # Without no_grad the graph would keep growing across all steps.
    with torch.no_grad():
        # Forward-diffuse x0 to the first timestep.
        var_t = noise_var(steps[0])
        x_t = torch.sqrt(1 - var_t) * x0 + torch.sqrt(var_t) * torch.randn_like(x0)
        for i, t in enumerate(steps):
            # One timestep value per batch item, so any batch size works.
            t_batch = torch.ones(x_t.shape[0], device=x_t.device) * t
            e_out = self.model.model(x_t, t_batch)
            a_t = 1 - noise_var(t)
            pred_x0 = (x_t - torch.sqrt(1 - a_t) * e_out) / a_t.sqrt()
            if i != last:
                a_prev = 1 - noise_var(steps[i + 1])
                sigma_t = eta * torch.sqrt((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev))
                dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_out
                # Exactly one noise draw. Two independent sigma_t-scaled draws
                # would inflate the injected std to sqrt(2) * sigma_t.
                x_t = a_prev.sqrt() * pred_x0 + dir_xt + sigma_t * torch.randn_like(x_t)
            recon = self.model.decode_first_stage(pred_x0)
            # Presumably [-1, 1] -> uint8; +128 rounds on the truncating cast.
            # NCHW -> NHWC, first batch item only.
            image_np = (
                (recon.detach() * 127.5 + 128)
                .clip(0, 255)
                .to(torch.uint8)
                .permute(0, 2, 3, 1)
                .cpu()
                .numpy()[0]
            )
            Image.fromarray(image_np, 'RGB').save(os.path.join(out_dir, f'{i}.png'))