feizhengcong's picture
Upload 198 files
074c857
import torch
import os
import torchvision.transforms.functional as TF
from torchvision.utils import make_grid
import numpy as np
from IPython import display
#
# Callback functions
#
class SamplerCallback:
    """Creates the per-step callback passed into the samplers.

    Depending on ``args.sampler``, the public ``callback`` attribute is bound to
    either :meth:`img_callback_` (CompVis ``plms``/``ddim`` samplers) or
    :meth:`k_callback_` (k-diffusion samplers). At every step the callback
    optionally does the following:

    - applies dynamic and/or static thresholding to the current latent, in place;
    - overwrites the masked region with the noised init latent, in place;
    - saves and/or displays the intermediate samples.
    """

    def __init__(self, args, root, mask=None, init_latent=None, sigmas=None, sampler=None,
                 verbose=False):
        """
        Args:
            args: settings object. Reads ``sampler``, ``dynamic_threshold``,
                ``static_threshold``, ``n_samples``, ``save_sample_per_step``,
                ``show_sample_per_step``, ``outdir``, ``timestring`` and ``seed``.
            root: object providing ``model`` and ``device``.
            mask: optional mask tensor. Positions whose mask value is non-zero and
                at least the current schedule threshold are overwritten with the
                noised ``init_latent`` at each step.
            init_latent: latent that the masked region is pulled back towards.
                A matching ``noise`` tensor is drawn once at construction.
            sigmas: noise schedule. It is normalized by its max and sorted to form
                ``mask_schedule``. It must be provided whenever ``mask`` is used;
                an empty schedule disables the mask.
            sampler: CompVis sampler. Its ``stochastic_encode`` is required for
                ``plms``/``ddim`` when a mask is given.
            verbose: when True, ``verbose_print`` is ``print``; otherwise a no-op.
        """
        self.model = root.model
        self.device = root.device
        self.sampler_name = args.sampler
        self.dynamic_threshold = args.dynamic_threshold
        self.static_threshold = args.static_threshold
        self.mask = mask
        self.init_latent = init_latent
        self.sigmas = sigmas
        self.sampler = sampler
        self.verbose = verbose
        self.batch_size = args.n_samples
        self.save_sample_per_step = args.save_sample_per_step
        self.show_sample_per_step = args.show_sample_per_step
        # One output directory per sample in the batch, used for per-step images.
        self.paths_to_image_steps = [
            os.path.join(args.outdir, f"{args.timestring}_{index:02}_{args.seed}")
            for index in range(args.n_samples)
        ]
        if self.save_sample_per_step:
            for path in self.paths_to_image_steps:
                os.makedirs(path, exist_ok=True)

        self.step_index = 0
        self.noise = None
        if init_latent is not None:
            self.noise = torch.randn_like(init_latent, device=self.device)

        self.mask_schedule = None
        # Both length checks must be guarded: sigmas defaults to None, and
        # len(None) would raise TypeError.
        if sigmas is not None:
            if len(sigmas) > 0:
                # Ascending thresholds: sigmas normalized by their maximum.
                self.mask_schedule, _ = torch.sort(sigmas / torch.max(sigmas))
            else:
                # No mask needed if there are no steps (usually happens because strength==1.0).
                self.mask = None

        is_compvis = self.sampler_name in ["plms", "ddim"]
        if is_compvis and mask is not None:
            assert sampler is not None, "Callback function for stable-diffusion samplers requires sampler variable"
        if is_compvis:
            # Callback formatted for CompVis latent diffusion samplers.
            self.callback = self.img_callback_
        else:
            # Default callback uses the k-diffusion sampler's argument dict.
            self.callback = self.k_callback_
        self.verbose_print = print if verbose else lambda *args, **kwargs: None

    def display_images(self, images):
        """Show ``images`` as a grid (4 per row) in the IPython output.

        Each call replaces the previous output. Values are mapped through
        ``(x + 1) / 2`` and clamped to [0, 1] before display.
        """
        images = images.double().cpu().add(1).div(2).clamp(0, 1)
        grid = make_grid(images, 4)
        display.clear_output(wait=True)
        display.display(TF.to_pil_image(grid))

    def view_sample_step(self, latents, path_name_modifier=''):
        """Save and/or display ``latents`` for the current step.

        When ``save_sample_per_step`` is set, the latents are decoded with
        ``model.decode_first_stage``. Sample ``i`` is then written to
        ``paths_to_image_steps[i]`` as ``{path_name_modifier}_{step_index:05}.png``.

        When ``show_sample_per_step`` is set, the latents are decoded with
        ``model.linear_decode``, ``path_name_modifier`` is printed, and the
        result is shown as a grid.
        """
        if self.save_sample_per_step:
            samples = self.model.decode_first_stage(latents)
            fname = f'{path_name_modifier}_{self.step_index:05}.png'
            for i, sample in enumerate(samples):
                sample = sample.double().cpu().add(1).div(2).clamp(0, 1)
                grid = make_grid(sample, 4).cpu()
                TF.to_pil_image(grid).save(os.path.join(self.paths_to_image_steps[i], fname))
        if self.show_sample_per_step:
            samples = self.model.linear_decode(latents)
            print(path_name_modifier)
            self.display_images(samples)

    def dynamic_thresholding_(self, img, threshold):
        """Dynamic thresholding (Imagen paper, May 2022), applied to ``img`` in place.

        For each sample along dim 0, the ``threshold``-th percentile of ``|img|``
        is computed. The largest of those values (and 1.0) becomes ``s``. ``img``
        is then clamped to [-s, s] and divided by ``s``.
        """
        per_sample = np.percentile(np.abs(img.cpu()), threshold, axis=tuple(range(1, img.ndim)))
        # NOTE(review): a single batch-wide scale is applied, not one per sample.
        s = np.max(np.append(per_sample, 1.0))
        torch.clamp_(img, -1 * s, s)
        img.div_(s)

    def _apply_thresholds(self, x):
        """Apply the configured dynamic and static thresholding to ``x`` in place."""
        if self.dynamic_threshold is not None:
            self.dynamic_thresholding_(x, self.dynamic_threshold)
        if self.static_threshold is not None:
            torch.clamp_(x, -1 * self.static_threshold, self.static_threshold)

    def _apply_mask(self, x, init_noise, step):
        """Overwrite masked positions of ``x`` (in place) with ``init_noise``.

        A position is masked when its mask value is non-zero and is at least
        ``mask_schedule[step]``.
        """
        is_masked = torch.logical_and(self.mask >= self.mask_schedule[step], self.mask != 0)
        # Pure selection: cast first, so the result matches blending in a wider
        # dtype and then copying into x.
        x.copy_(torch.where(is_masked, init_noise.to(x.dtype), x))

    def k_callback_(self, args_dict):
        """Callback for samplers in the k-diffusion repo.

        The sampler calls it as
        ``callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})``.
        ``args_dict['x']`` is modified in place.
        """
        step = args_dict['i']
        x = args_dict['x']
        self.step_index = step
        self._apply_thresholds(x)
        if self.mask is not None:
            # Init latent noised to the current sigma.
            init_noise = self.init_latent + self.noise * args_dict['sigma']
            self._apply_mask(x, init_noise, step)
        self.view_sample_step(args_dict['denoised'], "x0_pred")
        self.view_sample_step(x, "x")

    def img_callback_(self, img, pred_x0, i):
        """Callback for CompVis samplers, called with the image ``img`` and step ``i``.

        ``img`` is modified in place.
        """
        self.step_index = i
        self._apply_thresholds(img)
        if self.mask is not None:
            # The sampler step counts one way and the encode timestep the other,
            # so step i maps to index len(sigmas) - i - 1.
            i_inv = len(self.sigmas) - i - 1
            timesteps = torch.tensor([i_inv] * self.batch_size).to(self.device)
            init_noise = self.sampler.stochastic_encode(self.init_latent, timesteps, noise=self.noise)
            self._apply_mask(img, init_noise, i)
        self.view_sample_step(pred_x0, "x0_pred")
        self.view_sample_step(img, "x")