import re import pandas as pd import numpy as np from functools import reduce from einops import rearrange import sys from skimage.exposure import match_histograms import random from pytorch_lightning import seed_everything import time import torch from torch import autocast from PIL import Image import os import requests from omegaconf import OmegaConf from huggingface_hub import hf_hub_download # 1. Download stable diffusion repository and set path os.system("git clone https://github.com/kael558/stable-diffusion-cpu") os.system("git clone https://github.com/shariqfarooq123/AdaBins.git") os.system("git clone https://github.com/isl-org/MiDaS.git") os.system("git clone https://github.com/MSFTserver/pytorch3d-lite.git") os.system("git clone https://github.com/deforum/k-diffusion/") with open('k-diffusion/k_diffusion/__init__.py', 'w') as f: f.write('') sys.path.extend([ './taming-transformers', './clip', 'stable-diffusion/', 'k-diffusion', 'pytorch3d-lite', 'AdaBins', 'MiDaS', ]) from helpers import sampler_fn from k_diffusion.external import CompVisDenoiser from ldm.util import instantiate_from_config from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.plms import PLMSSampler # 2. Set model download config def load_model_from_config(config, ckpt, verbose=False, half_precision=False): map_location = "cuda" if torch.cuda.is_available() else "cpu" print(f"Loading model from {ckpt}") pl_sd = torch.load(ckpt, map_location=map_location) if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") sd_var = pl_sd["state_dict"] model = instantiate_from_config(config.model) m, u = model.load_state_dict(sd_var, strict=False) if len(m) > 0 and verbose: print("missing keys:") print(m) if len(u) > 0 and verbose: print("unexpected keys:") print(u) if half_precision: model = model.half() model.eval() return model print('Model loading...') ''' models_path = "models" model_checkpoint = "sd-v1-4.ckpt" ckpt_path = os.path.join(models_path, model_checkpoint) ckpt_path = snapshot_download(repo_id="CompVis/stable-diffusion-v-1-4-original", filename="sd-v1-4.ckpt") if os.path.exists(ckpt_path): print(f"{ckpt_path} exists") else: print(f"Attempting to download {model_checkpoint}...this may take a while") url = 'https://kael558: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@huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt' ckpt_request = requests.get(url) print('Model downloaded.') with open(os.path.join(models_path, model_checkpoint), 'wb') as model_file: model_file.write(ckpt_request.content) ''' ckpt_path = hf_hub_download(repo_id="CompVis/stable-diffusion-v-1-4-original", filename="sd-v1-4.ckpt") ckpt_config_path = "./stable-diffusion/configs/stable-diffusion/v1-inference.yaml" local_config = OmegaConf.load(f"{ckpt_config_path}") model = load_model_from_config(local_config, f"{ckpt_path}" ) device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model = model.to(device) print('Model saved.') class DeformAnimKeys(): def __init__(self, anim_args): self.angle_series = get_inbetweens(parse_key_frames(anim_args.angle), anim_args.max_frames) self.zoom_series = get_inbetweens(parse_key_frames(anim_args.zoom), anim_args.max_frames) self.translation_x_series = get_inbetweens(parse_key_frames(anim_args.translation_x), anim_args.max_frames) self.translation_y_series = get_inbetweens(parse_key_frames(anim_args.translation_y), anim_args.max_frames) self.translation_z_series = get_inbetweens(parse_key_frames(anim_args.translation_z), anim_args.max_frames) self.rotation_3d_x_series = get_inbetweens(parse_key_frames(anim_args.rotation_3d_x), anim_args.max_frames) self.rotation_3d_y_series = get_inbetweens(parse_key_frames(anim_args.rotation_3d_y), anim_args.max_frames) self.rotation_3d_z_series = get_inbetweens(parse_key_frames(anim_args.rotation_3d_z), anim_args.max_frames) self.perspective_flip_theta_series = get_inbetweens(parse_key_frames(anim_args.perspective_flip_theta), anim_args.max_frames) self.perspective_flip_phi_series = get_inbetweens(parse_key_frames(anim_args.perspective_flip_phi), anim_args.max_frames) self.perspective_flip_gamma_series = get_inbetweens(parse_key_frames(anim_args.perspective_flip_gamma), anim_args.max_frames) self.perspective_flip_fv_series = get_inbetweens(parse_key_frames(anim_args.perspective_flip_fv), anim_args.max_frames) self.noise_schedule_series = get_inbetweens(parse_key_frames(anim_args.noise_schedule), anim_args.max_frames) self.strength_schedule_series = get_inbetweens(parse_key_frames(anim_args.strength_schedule), anim_args.max_frames) self.contrast_schedule_series = get_inbetweens(parse_key_frames(anim_args.contrast_schedule), anim_args.max_frames) def check_is_number(value): float_pattern = r'^(?=.)([+-]?([0-9]*)(\.([0-9]+))?)$' return re.match(float_pattern, value) def get_inbetweens(key_frames, max_frames, integer=False, interp_method='Linear'): import numexpr key_frame_series = pd.Series([np.nan for a in range(max_frames)]) for i in range(0, max_frames): if i in key_frames: value = key_frames[i] value_is_number = check_is_number(value) # if it's only a number, leave the rest for the default interpolation if value_is_number: t = i key_frame_series[i] = value if not value_is_number: t = i key_frame_series[i] = numexpr.evaluate(value) key_frame_series = key_frame_series.astype(float) if interp_method == 'Cubic' and len(key_frames.items()) <= 3: interp_method = 'Quadratic' if interp_method == 'Quadratic' and len(key_frames.items()) <= 2: interp_method = 'Linear' key_frame_series[0] = key_frame_series[key_frame_series.first_valid_index()] key_frame_series[max_frames-1] = key_frame_series[key_frame_series.last_valid_index()] key_frame_series = key_frame_series.interpolate(method=interp_method.lower(), limit_direction='both') if integer: return key_frame_series.astype(int) return key_frame_series def parse_key_frames(string, prompt_parser=None): # because math functions (i.e. sin(t)) can utilize brackets # it extracts the value in form of some stuff # which has previously been enclosed with brackets and # with a comma or end of line existing after the closing one pattern = r'((?P[0-9]+):[\s]*\((?P[\S\s]*?)\)([,][\s]?|[\s]?$))' frames = dict() for match_object in re.finditer(pattern, string): frame = int(match_object.groupdict()['frame']) param = match_object.groupdict()['param'] if prompt_parser: frames[frame] = prompt_parser(param) else: frames[frame] = param if frames == {} and len(string) != 0: raise RuntimeError('Key Frame string not correctly formatted') return frames # https://en.wikipedia.org/wiki/Rotation_matrix def getRotationMatrixManual(rotation_angles): rotation_angles = [np.deg2rad(x) for x in rotation_angles] phi = rotation_angles[0] # around x gamma = rotation_angles[1] # around y theta = rotation_angles[2] # around z # X rotation Rphi = np.eye(4, 4) sp = np.sin(phi) cp = np.cos(phi) Rphi[1, 1] = cp Rphi[2, 2] = Rphi[1, 1] Rphi[1, 2] = -sp Rphi[2, 1] = sp # Y rotation Rgamma = np.eye(4, 4) sg = np.sin(gamma) cg = np.cos(gamma) Rgamma[0, 0] = cg Rgamma[2, 2] = Rgamma[0, 0] Rgamma[0, 2] = sg Rgamma[2, 0] = -sg # Z rotation (in-image-plane) Rtheta = np.eye(4, 4) st = np.sin(theta) ct = np.cos(theta) Rtheta[0, 0] = ct Rtheta[1, 1] = Rtheta[0, 0] Rtheta[0, 1] = -st Rtheta[1, 0] = st R = reduce(lambda x, y: np.matmul(x, y), [Rphi, Rgamma, Rtheta]) return R def getPoints_for_PerspectiveTranformEstimation(ptsIn, ptsOut, W, H, sidelength): ptsIn2D = ptsIn[0, :] ptsOut2D = ptsOut[0, :] ptsOut2Dlist = [] ptsIn2Dlist = [] for i in range(0, 4): ptsOut2Dlist.append([ptsOut2D[i, 0], ptsOut2D[i, 1]]) ptsIn2Dlist.append([ptsIn2D[i, 0], ptsIn2D[i, 1]]) pin = np.array(ptsIn2Dlist) + [W / 2., H / 2.] pout = (np.array(ptsOut2Dlist) + [1., 1.]) * (0.5 * sidelength) pin = pin.astype(np.float32) pout = pout.astype(np.float32) return pin, pout def warpMatrix(W, H, theta, phi, gamma, scale, fV): # M is to be estimated M = np.eye(4, 4) fVhalf = np.deg2rad(fV / 2.) d = np.sqrt(W * W + H * H) sideLength = scale * d / np.cos(fVhalf) h = d / (2.0 * np.sin(fVhalf)) n = h - (d / 2.0) f = h + (d / 2.0) # Translation along Z-axis by -h T = np.eye(4, 4) T[2, 3] = -h # Rotation matrices around x,y,z R = getRotationMatrixManual([phi, gamma, theta]) # Projection Matrix P = np.eye(4, 4) P[0, 0] = 1.0 / np.tan(fVhalf) P[1, 1] = P[0, 0] P[2, 2] = -(f + n) / (f - n) P[2, 3] = -(2.0 * f * n) / (f - n) P[3, 2] = -1.0 # pythonic matrix multiplication F = reduce(lambda x, y: np.matmul(x, y), [P, T, R]) # shape should be 1,4,3 for ptsIn and ptsOut since perspectiveTransform() expects data in this way. # In C++, this can be achieved by Mat ptsIn(1,4,CV_64FC3); ptsIn = np.array([[ [-W / 2., H / 2., 0.], [W / 2., H / 2., 0.], [W / 2., -H / 2., 0.], [-W / 2., -H / 2., 0.] ]]) ptsOut = np.array(np.zeros((ptsIn.shape), dtype=ptsIn.dtype)) ptsOut = cv2.perspectiveTransform(ptsIn, F) ptsInPt2f, ptsOutPt2f = getPoints_for_PerspectiveTranformEstimation(ptsIn, ptsOut, W, H, sideLength) # check float32 otherwise OpenCV throws an error assert (ptsInPt2f.dtype == np.float32) assert (ptsOutPt2f.dtype == np.float32) M33 = cv2.getPerspectiveTransform(ptsInPt2f, ptsOutPt2f) return M33, sideLength def anim_frame_warp_2d(prev_img_cv2, args, anim_args, keys, frame_idx): angle = keys.angle_series[frame_idx] zoom = keys.zoom_series[frame_idx] translation_x = keys.translation_x_series[frame_idx] translation_y = keys.translation_y_series[frame_idx] center = (args.W // 2, args.H // 2) trans_mat = np.float32([[1, 0, translation_x], [0, 1, translation_y]]) rot_mat = cv2.getRotationMatrix2D(center, angle, zoom) trans_mat = np.vstack([trans_mat, [0, 0, 1]]) rot_mat = np.vstack([rot_mat, [0, 0, 1]]) xform = np.matmul(rot_mat, trans_mat) return cv2.warpPerspective( prev_img_cv2, xform, (prev_img_cv2.shape[1], prev_img_cv2.shape[0]), borderMode=cv2.BORDER_WRAP if anim_args.border == 'wrap' else cv2.BORDER_REPLICATE ) def sample_from_cv2(sample: np.ndarray) -> torch.Tensor: sample = ((sample.astype(float) / 255.0) * 2) - 1 sample = sample[None].transpose(0, 3, 1, 2).astype(np.float16) sample = torch.from_numpy(sample) return sample def sample_to_cv2(sample: torch.Tensor, type=np.uint8) -> np.ndarray: sample_f32 = rearrange(sample.squeeze().cpu().numpy(), "c h w -> h w c").astype(np.float32) sample_f32 = ((sample_f32 * 0.5) + 0.5).clip(0, 1) sample_int8 = (sample_f32 * 255) return sample_int8.astype(type) def maintain_colors(prev_img, color_match_sample, mode): if mode == 'Match Frame 0 RGB': return match_histograms(prev_img, color_match_sample, multichannel=True) elif mode == 'Match Frame 0 HSV': prev_img_hsv = cv2.cvtColor(prev_img, cv2.COLOR_RGB2HSV) color_match_hsv = cv2.cvtColor(color_match_sample, cv2.COLOR_RGB2HSV) matched_hsv = match_histograms(prev_img_hsv, color_match_hsv, multichannel=True) return cv2.cvtColor(matched_hsv, cv2.COLOR_HSV2RGB) else: # Match Frame 0 LAB prev_img_lab = cv2.cvtColor(prev_img, cv2.COLOR_RGB2LAB) color_match_lab = cv2.cvtColor(color_match_sample, cv2.COLOR_RGB2LAB) matched_lab = match_histograms(prev_img_lab, color_match_lab, multichannel=True) return cv2.cvtColor(matched_lab, cv2.COLOR_LAB2RGB) def add_noise(sample: torch.Tensor, noise_amt: float) -> torch.Tensor: return sample + torch.randn(sample.shape, device=sample.device) * noise_amt def next_seed(args): if args.seed_behavior == 'iter': args.seed += 1 elif args.seed_behavior == 'fixed': pass # always keep seed the same else: args.seed = random.randint(0, 2**32 - 1) return args.seed def generate(args, return_c=False): seed_everything(args.seed) sampler = PLMSSampler(model) if args.sampler == 'plms' else DDIMSampler(model) model_wrap = CompVisDenoiser(model) batch_size = args.n_samples prompt = args.prompt assert prompt is not None data = [batch_size * [prompt]] precision_scope = autocast mask = None t_enc = int((1.0-args.strength) * args.steps) init_latent = None # Noise schedule for the k-diffusion samplers (used for masking) k_sigmas = model_wrap.get_sigmas(args.steps) k_sigmas = k_sigmas[len(k_sigmas)-t_enc-1:] if args.sampler in ['plms','ddim']: sampler.make_schedule(ddim_num_steps=args.steps, ddim_eta=args.ddim_eta, ddim_discretize='fill', verbose=False) callback = SamplerCallback(args=args, mask=mask, init_latent=init_latent, sigmas=k_sigmas, sampler=sampler, verbose=False).callback results = [] with torch.no_grad(): with precision_scope("cuda"): with model.ema_scope(): for prompts in data: uc = model.get_learned_conditioning(batch_size * [""]) c = model.get_learned_conditioning(prompts) if args.scale == 1.0: uc = None if args.init_c != None: c = args.init_c if args.sampler in ["klms","dpm2","dpm2_ancestral","heun","euler","euler_ancestral"]: samples = sampler_fn( c=c, uc=uc, args=args, model_wrap=model_wrap, init_latent=init_latent, t_enc=t_enc, device=device, cb=callback) else: # args.sampler == 'plms' or args.sampler == 'ddim': if init_latent is not None and args.strength > 0: z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device)) else: z_enc = torch.randn([args.n_samples, args.C, args.H // args.f, args.W // args.f], device=device) if args.sampler == 'ddim': samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=args.scale, unconditional_conditioning=uc, img_callback=callback) elif args.sampler == 'plms': # no "decode" function in plms, so use "sample" shape = [args.C, args.H // args.f, args.W // args.f] samples, _ = sampler.sample(S=args.steps, conditioning=c, batch_size=args.n_samples, shape=shape, verbose=False, unconditional_guidance_scale=args.scale, unconditional_conditioning=uc, eta=args.ddim_eta, x_T=z_enc, img_callback=callback) else: raise Exception(f"Sampler {args.sampler} not recognised.") x_samples = model.decode_first_stage(samples) x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) if return_c: results.append(c.clone()) for x_sample in x_samples: x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') image = Image.fromarray(x_sample.astype(np.uint8)) results.append(image) return results def get_output_folder(output_path, batch_folder): out_path = os.path.join(output_path,time.strftime('%Y-%m')) if batch_folder != "": out_path = os.path.join(out_path, batch_folder) os.makedirs(out_path, exist_ok=True) return out_path def DeforumArgs(): #@markdown **Image Settings** W = 512 #@param H = 512 #@param W, H = map(lambda x: x - x % 64, (W, H)) # resize to integer multiple of 64 #@markdown **Sampling Settings** seed = -1 #@param sampler = 'dpm2_ancestral' #@param ["klms","dpm2","dpm2_ancestral","heun","euler","euler_ancestral","plms", "ddim"] steps = 50 #@param scale = 7 #@param ddim_eta = 0.0 #@param dynamic_threshold = None static_threshold = None #@markdown **Save & Display Settings** save_samples = True #@param {type:"boolean"} save_settings = True #@param {type:"boolean"} display_samples = True #@param {type:"boolean"} save_sample_per_step = False #@param {type:"boolean"} show_sample_per_step = False #@param {type:"boolean"} #@markdown **Prompt Settings** prompt_weighting = False #@param {type:"boolean"} normalize_prompt_weights = True #@param {type:"boolean"} log_weighted_subprompts = False #@param {type:"boolean"} #@markdown **Batch Settings** n_batch = 1 #@param batch_name = "StableFun" #@param {type:"string"} filename_format = "{timestring}_{index}_{prompt}.png" #@param ["{timestring}_{index}_{seed}.png","{timestring}_{index}_{prompt}.png"] seed_behavior = "fixed" #@param ["iter","fixed","random"] make_grid = False #@param {type:"boolean"} grid_rows = 2 #@param #@markdown **Init Settings** use_init = False #@param {type:"boolean"} strength = 0.0 #@param {type:"number"} strength_0_no_init = True # Set the strength to 0 automatically when no init image is used init_image = "https://cdn.pixabay.com/photo/2022/07/30/13/10/green-longhorn-beetle-7353749_1280.jpg" #@param {type:"string"} # Whiter areas of the mask are areas that change more use_mask = False #@param {type:"boolean"} use_alpha_as_mask = False # use the alpha channel of the init image as the mask mask_file = "https://www.filterforge.com/wiki/images/archive/b/b7/20080927223728%21Polygonal_gradient_thumb.jpg" #@param {type:"string"} invert_mask = False #@param {type:"boolean"} # Adjust mask image, 1.0 is no adjustment. Should be positive numbers. mask_brightness_adjust = 1.0 #@param {type:"number"} mask_contrast_adjust = 1.0 #@param {type:"number"} # Overlay the masked image at the end of the generation so it does not get degraded by encoding and decoding overlay_mask = True # {type:"boolean"} # Blur edges of final overlay mask, if used. Minimum = 0 (no blur) mask_overlay_blur = 5 # {type:"number"} n_samples = 1 # doesnt do anything precision = 'autocast' C = 4 f = 8 prompt = "" timestring = "" init_latent = None init_sample = None init_c = None return locals() def DeforumAnimArgs(): #@markdown ####**Animation:** animation_mode = '2D' #@param ['None', '2D', '3D', 'Video Input', 'Interpolation'] {type:'string'} max_frames = 100 #@param {type:"number"} border = 'replicate' #@param ['wrap', 'replicate'] {type:'string'} #@markdown ####**Motion Parameters:** angle = "0:(0)"#@param {type:"string"} zoom = "0:(1.04)"#@param {type:"string"} translation_x = "0:(10*sin(2*3.14*t/10))"#@param {type:"string"} translation_y = "0:(0)"#@param {type:"string"} translation_z = "0:(10)"#@param {type:"string"} rotation_3d_x = "0:(0)"#@param {type:"string"} rotation_3d_y = "0:(0)"#@param {type:"string"} rotation_3d_z = "0:(0)"#@param {type:"string"} flip_2d_perspective = False #@param {type:"boolean"} perspective_flip_theta = "0:(0)"#@param {type:"string"} perspective_flip_phi = "0:(t%15)"#@param {type:"string"} perspective_flip_gamma = "0:(0)"#@param {type:"string"} perspective_flip_fv = "0:(53)"#@param {type:"string"} noise_schedule = "0: (0.02)"#@param {type:"string"} strength_schedule = "0: (0.65)"#@param {type:"string"} contrast_schedule = "0: (1.0)"#@param {type:"string"} #@markdown ####**Coherence:** color_coherence = 'Match Frame 0 LAB' #@param ['None', 'Match Frame 0 HSV', 'Match Frame 0 LAB', 'Match Frame 0 RGB'] {type:'string'} diffusion_cadence = '7' #@param ['1','2','3','4','5','6','7','8'] {type:'string'} #@markdown ####**3D Depth Warping:** use_depth_warping = True #@param {type:"boolean"} midas_weight = 0.3#@param {type:"number"} near_plane = 200 far_plane = 10000 fov = 40#@param {type:"number"} padding_mode = 'border'#@param ['border', 'reflection', 'zeros'] {type:'string'} sampling_mode = 'bicubic'#@param ['bicubic', 'bilinear', 'nearest'] {type:'string'} save_depth_maps = False #@param {type:"boolean"} #@markdown ####**Video Input:** video_init_path ='/content/video_in.mp4'#@param {type:"string"} extract_nth_frame = 1#@param {type:"number"} overwrite_extracted_frames = True #@param {type:"boolean"} use_mask_video = False #@param {type:"boolean"} video_mask_path ='/content/video_in.mp4'#@param {type:"string"} #@markdown ####**Interpolation:** interpolate_key_frames = False #@param {type:"boolean"} interpolate_x_frames = 4 #@param {type:"number"} #@markdown ####**Resume Animation:** resume_from_timestring = False #@param {type:"boolean"} resume_timestring = "20220829210106" #@param {type:"string"} return locals() # # Callback functions # class SamplerCallback(object): # Creates the callback function to be passed into the samplers for each step def __init__(self, args, mask=None, init_latent=None, sigmas=None, sampler=None, verbose=False): self.sampler_name = args.sampler self.dynamic_threshold = args.dynamic_threshold self.static_threshold = args.static_threshold self.mask = mask self.init_latent = init_latent self.sigmas = sigmas self.sampler = sampler self.verbose = verbose self.batch_size = args.n_samples #self.save_sample_per_step = args.save_sample_per_step #self.show_sample_per_step = args.show_sample_per_step #self.paths_to_image_steps = [os.path.join( args.outdir, f"{args.timestring}_{index:02}_{args.seed}") for index in range(args.n_samples) ] #if self.save_sample_per_step: # for path in self.paths_to_image_steps: # os.makedirs(path, exist_ok=True) self.step_index = 0 self.noise = None if init_latent is not None: self.noise = torch.randn_like(init_latent, device=device) self.mask_schedule = None if sigmas is not None and len(sigmas) > 0: self.mask_schedule, _ = torch.sort(sigmas/torch.max(sigmas)) elif len(sigmas) == 0: self.mask = None # no mask needed if no steps (usually happens because strength==1.0) if self.sampler_name in ["plms","ddim"]: if mask is not None: assert sampler is not None, "Callback function for stable-diffusion samplers requires sampler variable" if self.sampler_name in ["plms","ddim"]: # Callback function formated for compvis latent diffusion samplers self.callback = self.img_callback_ else: # Default callback function uses k-diffusion sampler variables self.callback = self.k_callback_ #self.verbose_print = print if verbose else lambda *args, **kwargs: None # The callback function is applied to the image at each step def dynamic_thresholding_(self, img, threshold): # Dynamic thresholding from Imagen paper (May 2022) s = np.percentile(np.abs(img.cpu()), threshold, axis=tuple(range(1,img.ndim))) s = np.max(np.append(s,1.0)) torch.clamp_(img, -1*s, s) torch.FloatTensor.div_(img, s) # Callback for samplers in the k-diffusion repo, called thus: # callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) def k_callback_(self, args_dict): self.step_index = args_dict['i'] if self.dynamic_threshold is not None: self.dynamic_thresholding_(args_dict['x'], self.dynamic_threshold) if self.static_threshold is not None: torch.clamp_(args_dict['x'], -1*self.static_threshold, self.static_threshold) if self.mask is not None: init_noise = self.init_latent + self.noise * args_dict['sigma'] is_masked = torch.logical_and(self.mask >= self.mask_schedule[args_dict['i']], self.mask != 0 ) new_img = init_noise * torch.where(is_masked,1,0) + args_dict['x'] * torch.where(is_masked,0,1) args_dict['x'].copy_(new_img) #self.view_sample_step(args_dict['denoised'], "x0_pred") # Callback for Compvis samplers # Function that is called on the image (img) and step (i) at each step def img_callback_(self, img, i): self.step_index = i # Thresholding functions if self.dynamic_threshold is not None: self.dynamic_thresholding_(img, self.dynamic_threshold) if self.static_threshold is not None: torch.clamp_(img, -1*self.static_threshold, self.static_threshold) if self.mask is not None: i_inv = len(self.sigmas) - i - 1 init_noise = self.sampler.stochastic_encode(self.init_latent, torch.tensor([i_inv]*self.batch_size).to(device), noise=self.noise) is_masked = torch.logical_and(self.mask >= self.mask_schedule[i], self.mask != 0 ) new_img = init_noise * torch.where(is_masked,1,0) + img * torch.where(is_masked,0,1) img.copy_(new_img) #self.view_sample_step(img, "x")