|
|
import torch |
|
|
import os |
|
|
import torchvision.transforms.functional as TF |
|
|
from torchvision.utils import make_grid |
|
|
import numpy as np |
|
|
from IPython import display |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SamplerCallback(object):
    """Per-step callback for diffusion samplers.

    On every sampler step this applies optional dynamic/static thresholding
    and (optionally) composites a re-noised init latent over unmasked regions,
    and can save and/or display intermediate decoded samples.

    ``self.callback`` is bound to the variant matching the chosen sampler:
    ``img_callback_`` for the stable-diffusion "plms"/"ddim" samplers,
    ``k_callback_`` for k-diffusion samplers.
    """

    def __init__(self, args, root, mask=None, init_latent=None, sigmas=None, sampler=None,
                 verbose=False):
        # Model and device come from the shared root object.
        self.model = root.model
        self.device = root.device
        self.sampler_name = args.sampler
        self.dynamic_threshold = args.dynamic_threshold
        self.static_threshold = args.static_threshold
        self.mask = mask
        self.init_latent = init_latent
        self.sigmas = sigmas
        self.sampler = sampler
        self.verbose = verbose
        self.batch_size = args.n_samples
        self.save_sample_per_step = args.save_sample_per_step
        self.show_sample_per_step = args.show_sample_per_step
        # One output directory per sample in the batch.
        self.paths_to_image_steps = [
            os.path.join(args.outdir, f"{args.timestring}_{index:02}_{args.seed}")
            for index in range(args.n_samples)
        ]

        if self.save_sample_per_step:
            for path in self.paths_to_image_steps:
                os.makedirs(path, exist_ok=True)

        self.step_index = 0

        self.noise = None
        if init_latent is not None:
            self.noise = torch.randn_like(init_latent, device=self.device)

        # Ascending, max-normalized sigma schedule used to decide when each
        # mask region becomes active.
        self.mask_schedule = None
        if sigmas is not None and len(sigmas) > 0:
            self.mask_schedule, _ = torch.sort(sigmas / torch.max(sigmas))
        else:
            # BUGFIX: the original used `elif len(sigmas) == 0`, which raised
            # TypeError when sigmas was None.  Without a schedule, masking
            # cannot be applied, so disable it for both None and empty sigmas.
            self.mask = None

        if self.sampler_name in ["plms", "ddim"]:
            if mask is not None:
                assert sampler is not None, "Callback function for stable-diffusion samplers requires sampler variable"
            self.callback = self.img_callback_
        else:
            self.callback = self.k_callback_

        # No-op printer unless verbose output was requested.
        self.verbose_print = print if verbose else lambda *args, **kwargs: None

    def display_images(self, images):
        """Display a batch of images (values in [-1, 1]) as a grid in the notebook."""
        # Map from [-1, 1] to [0, 1] for display.  (The original round-tripped
        # through numpy here; that was a value-identical no-op for a CPU tensor.)
        images = images.double().cpu().add(1).div(2).clamp(0, 1)
        grid = make_grid(images, 4).cpu()
        display.clear_output(wait=True)
        display.display(TF.to_pil_image(grid))

    def view_sample_step(self, latents, path_name_modifier=''):
        """Save and/or display the decoded latents for the current step.

        `path_name_modifier` distinguishes the saved artifacts (e.g. "x" vs
        "x0_pred"); it is embedded in the filename together with the step index.
        """
        if self.save_sample_per_step:
            samples = self.model.decode_first_stage(latents)
            fname = f'{path_name_modifier}_{self.step_index:05}.png'
            for i, sample in enumerate(samples):
                sample = sample.double().cpu().add(1).div(2).clamp(0, 1)
                grid = make_grid(sample, 4).cpu()
                TF.to_pil_image(grid).save(os.path.join(self.paths_to_image_steps[i], fname))
        if self.show_sample_per_step:
            # NOTE(review): uses `linear_decode` for a cheap preview rather than
            # the full first-stage decode — confirm the wrapped model provides it.
            samples = self.model.linear_decode(latents)
            print(path_name_modifier)
            self.display_images(samples)

    def dynamic_thresholding_(self, img, threshold):
        """Dynamically threshold `img` in place.

        Computes the `threshold` percentile of |img| over all non-batch axes,
        takes the maximum over the batch (floored at 1.0), clamps `img` to
        [-s, s] and rescales it by 1/s.
        """
        s = np.percentile(np.abs(img.cpu()), threshold, axis=tuple(range(1, img.ndim)))
        # Collapse per-sample percentiles into one batch-wide scalar, never < 1.0.
        s = np.max(np.append(s, 1.0))
        torch.clamp_(img, -1 * s, s)
        # BUGFIX: was `torch.FloatTensor.div_(img, s)` — an unbound-method call
        # tied to the CPU float32 tensor type; `img.div_` is the equivalent
        # in-place division that works for any dtype/device.
        img.div_(s)

    def k_callback_(self, args_dict):
        """Per-step callback for k-diffusion samplers.

        `args_dict` carries 'i' (step index), 'x' (current latent, mutated in
        place), 'sigma' (current noise level) and 'denoised' (x0 prediction).
        """
        self.step_index = args_dict['i']
        if self.dynamic_threshold is not None:
            self.dynamic_thresholding_(args_dict['x'], self.dynamic_threshold)
        if self.static_threshold is not None:
            torch.clamp_(args_dict['x'], -1 * self.static_threshold, self.static_threshold)
        if self.mask is not None:
            # Re-noise the init latent to the current sigma level, then paste
            # it over regions whose mask value has not yet "activated".
            init_noise = self.init_latent + self.noise * args_dict['sigma']
            is_masked = torch.logical_and(self.mask >= self.mask_schedule[args_dict['i']], self.mask != 0)
            new_img = init_noise * torch.where(is_masked, 1, 0) + args_dict['x'] * torch.where(is_masked, 0, 1)
            args_dict['x'].copy_(new_img)

        self.view_sample_step(args_dict['denoised'], "x0_pred")
        self.view_sample_step(args_dict['x'], "x")

    def img_callback_(self, img, pred_x0, i):
        """Per-step callback for the stable-diffusion "plms"/"ddim" samplers.

        `img` is the current latent (mutated in place), `pred_x0` the x0
        prediction, `i` the step index.
        """
        self.step_index = i

        if self.dynamic_threshold is not None:
            self.dynamic_thresholding_(img, self.dynamic_threshold)
        if self.static_threshold is not None:
            torch.clamp_(img, -1 * self.static_threshold, self.static_threshold)
        if self.mask is not None:
            # DDIM steps run from high to low noise, so invert the index to
            # pick the matching encode timestep.
            i_inv = len(self.sigmas) - i - 1
            init_noise = self.sampler.stochastic_encode(self.init_latent, torch.tensor([i_inv]*self.batch_size).to(self.device), noise=self.noise)
            is_masked = torch.logical_and(self.mask >= self.mask_schedule[i], self.mask != 0)
            new_img = init_noise * torch.where(is_masked, 1, 0) + img * torch.where(is_masked, 0, 1)
            img.copy_(new_img)

        self.view_sample_step(pred_x0, "x0_pred")
        self.view_sample_step(img, "x")