Spaces:
Sleeping
Sleeping
| import torch | |
| from torchvision import transforms | |
| import numpy as np | |
| import skfmm | |
| from PIL import Image | |
| import torch.nn as nn | |
| import cv2 | |
| import scipy | |
| from scipy.ndimage.filters import gaussian_filter | |
| import kornia | |
| import warnings | |
| warnings.filterwarnings("ignore", message="PyTorch version 1.7.1 or higher is recommended") | |
| import alpha_clip | |
| from augmentations import ImageAugmentations | |
| from constants import Const, N | |
def get_dist_field(dist_from, device, as_squeezed_np=False):
    """Signed distance field of a binary mask, computed with ``skfmm``.

    The mask is turned into a level-set function that is -0.5 on mask
    (non-zero) pixels and +0.5 elsewhere, so the zero contour lies on the mask
    boundary. The result is negative inside the mask and positive outside,
    in pixel units (``dx=1``).

    Args:
        dist_from: binary mask as a numpy array or torch tensor (any device,
            may require grad); values must not exceed 1.
        device: torch device of the returned tensor (unused when
            ``as_squeezed_np`` is True).
        as_squeezed_np: if True, return the numpy array produced by
            ``skfmm``. Despite the name, no squeezing is done; the shape
            matches ``dist_from``.

    Returns:
        The distance field as a numpy array, or as a float64 torch tensor on
        ``device``.

    Raises:
        AssertionError: if ``dist_from`` has a value greater than 1.
    """
    if not isinstance(dist_from, np.ndarray):
        # detach() so tensors that require grad can be converted too.
        dist_from = dist_from.detach().cpu().numpy()
    assert np.max(dist_from) <= 1
    level_set = np.where(dist_from, -0.5, 0.5)
    dist_field = skfmm.distance(level_set, dx=1)
    if as_squeezed_np:
        return dist_field
    return torch.tensor(dist_field).to(device)
def get_surround(surround_from, surround_width, device, as_squeezed_np=False):
    """Grow a binary mask by ``surround_width`` pixels using a distance field.

    The result selects every pixel whose signed distance to the mask boundary
    (negative inside, positive outside; see ``get_dist_field``) is at most
    ``surround_width``. A positive width therefore yields the mask plus an
    outer band, and a negative width erodes the mask.

    Args:
        surround_from: binary mask as a torch tensor or numpy array.
        surround_width: band width in pixels.
        device: torch device for the distance computation result.
        as_squeezed_np: if True, return a numpy array (not squeezed; the
            shape matches the input).

    Returns:
        A mask with the dtype of ``surround_from``, as a torch tensor on
        ``device`` or as a numpy array.
    """
    dists = get_dist_field(surround_from, device)
    # torch.as_tensor maps numpy dtypes to torch dtypes, so numpy masks work
    # as well as tensors (Tensor.to() rejects a numpy dtype).
    out_dtype = torch.as_tensor(surround_from).dtype
    surround = (dists <= surround_width).to(out_dtype)
    if as_squeezed_np:
        return surround.cpu().numpy()
    return surround
class DynMask:
    """Click-seeded inpainting mask that evolves using Alpha-CLIP gradients.

    State:
        base_potential: float64 field built once from the click.
        potential: half-precision field for the current step.
        latent_mask: binary half mask derived from ``potential`` by
            ``get_threshed_mask``.

    ``calc_grads`` differentiates an Alpha-CLIP loss of the decoded,
    mask-blended latents with respect to the latent mask. ``evolve_mask`` uses
    those gradients to push the potential inside a band around the current
    mask boundary.
    """

    def __init__(self, click_pil, args, init_image_tensor, device, total_steps):
        """Load Alpha-CLIP, encode the prompt and build the initial masks.

        Args:
            click_pil: PIL image of the user click. After grayscale conversion
                and resizing to the latent size, pixels brighter than 125
                count as the click.
            args: options object; ``alpha_clip_scale`` and ``prompt`` are read.
            init_image_tensor: stored as ``self.init_image``.
            device: torch device for the model and all tensors.
            total_steps: stored as ``self.total_steps``.
        """
        self.args = args
        self.device = device
        self.init_image = init_image_tensor
        self.total_steps = total_steps
        self.ac_size = (self.args.alpha_clip_scale, self.args.alpha_clip_scale)
        # The Alpha-CLIP checkpoint must match the requested input resolution.
        if self.args.alpha_clip_scale == 336:
            ac_model_name = "ViT-L/14@336px"
            ac_ckpt_path = "./checkpoints/clip_l14_336_grit1m_fultune_8xe.pth"
        else:
            ac_model_name = "ViT-L/14"
            ac_ckpt_path = "./checkpoints/clip_l14_grit20m_fultune_2xe.pth"
        self.ac_model, self.ac_preprocess = alpha_clip.load(
            ac_model_name,
            alpha_vision_ckpt_pth=ac_ckpt_path,
            device=self.device,
        )
        self.image_augmentations = ImageAugmentations(
            self.args.alpha_clip_scale, Const.AUG_NUM
        )
        self.text_features = self.get_text_features([self.args.prompt])
        self.latent_size = Const.LATENT_SIZE
        self.decoded_size = (Const.H, Const.W)
        self.thresh_val = Const.THRESH_VAL
        self.base_potential = None
        self.potential = None
        self.latent_mask = None
        self.set_init_masks(click_pil)
        # Named snapshots of (potential, latent_mask); see make_cached_masks_clones.
        self.cached_masks_clones = {}
        # Alpha-CLIP loss per step, filled by calc_grads.
        self.closs_hist = {}
        self.latents_hist = {}
        self.latent_masks_hist = {}

    def normalize_point_size(self, click, radius_for64=1.367):
        """Replace a click blob with a disk of canonical size at its centroid.

        Args:
            click: 2-D numpy array; values > 0.5 are treated as the click.
            radius_for64: disk radius for a 64-pixel-tall grid. It is scaled
                linearly with ``click.shape[0]`` (minus 0.3 pixel).

        Returns:
            Float numpy array with the shape of ``click``: 1 inside the
            (distance-field based, approximately round) disk, 0 elsewhere.

        Raises:
            ValueError: if no pixel of ``click`` exceeds 0.5. This can happen
                when a small click vanishes after nearest-neighbour resizing.
        """
        rows, cols = np.nonzero(click > 0.5)
        if rows.size == 0:
            raise ValueError(
                "click mask is empty; the click may have vanished when resized "
                "to the latent grid"
            )
        center = int(rows.mean().round()), int(cols.mean().round())
        point = np.zeros(click.shape, dtype=float)
        point[center] = 1
        return get_surround(
            torch.tensor(point).to(self.device),
            click.shape[0] / 64 * radius_for64 - 0.3,
            self.device,
            as_squeezed_np=True,
        )

    def calc_potential(self, click_pil, sigma_for_shape64):
        """Build the raw potential: a Gaussian bump centred on the click.

        The click image is converted to grayscale, resized (nearest) to
        ``self.latent_size``, binarised at 125, and replaced by a canonical
        disk. It is then Gaussian-filtered with a sigma scaled from a
        64-pixel grid and min-max normalised to [0, 1].

        Returns:
            Half tensor of shape (1, 1, H, W) on ``self.device``.
        """
        click = click_pil.convert("L").resize(self.latent_size, Image.NEAREST)
        click = (np.array(click) > 125).astype(float)
        click = self.normalize_point_size(
            click, radius_for64=Const.POINT_ON_LATENT_RADIUS
        )
        potential = gaussian_filter(
            click, sigma=sigma_for_shape64 * (click.shape[0]) / 64
        )
        low, high = np.min(potential), np.max(potential)
        potential = (potential - low) / max(high - low, 1e-8)
        potential = potential[np.newaxis, np.newaxis, ...]
        return torch.from_numpy(potential).half().to(self.device)

    def set_init_masks(self, click_pil, stretch_factor=1.0):
        """Create ``base_potential`` from the click and the step-0 masks.

        The normalised potential in [0, 1] is mapped linearly to
        [-1, Const.POTENTIAL_PEAK] and then multiplied by ``stretch_factor``.
        """
        potential = self.calc_potential(
            click_pil, sigma_for_shape64=Const.SIGMA_FOR_SHAPE64
        )
        base = potential.detach().to(torch.float64)
        if base.ndim == 2:  # defensive; calc_potential already returns 4-D
            base = base.unsqueeze(0).unsqueeze(0)
        base = base * (Const.POTENTIAL_PEAK - (-1)) - 1
        self.base_potential = stretch_factor * base
        self.set_cur_masks(step_i=0)

    def set_cur_masks(
        self, step_i, grads_to_update=None, surround_ring=None, return_only=None
    ):
        """Recompute ``potential`` and ``latent_mask`` for step ``step_i``.

        potential = blur(base_potential + bias(step_i)
                         [+ surround_ring * MASK_LR * grads_to_update])

        If the blurred potential is entirely <= 0, it is raised by
        ``Const.ADDITION_IN_COLLAPSE``. If it is entirely >= 0, it is lowered
        by the same amount. Both cases print a message.

        Returns:
            The same as ``get_curr_masks(return_only)``.
        """
        potential = self.base_potential + self.get_bias(step_i)
        if grads_to_update is not None:
            # Gradients only act on pixels inside the boundary band.
            potential = potential + (surround_ring * Const.MASK_LR * grads_to_update)
        # NOTE(review): the original indentation was lost. The blur is assumed
        # to run on every call, not only when gradients are given — confirm.
        potential = transforms.GaussianBlur(
            Const.GAUSS_K_MASK, sigma=Const.GAUSS_SIGMA_MASK
        )(potential)
        if torch.all(potential <= 0):
            potential += Const.ADDITION_IN_COLLAPSE
            print(
                f"{'*' * 10} Mask shrunk entirely, added {Const.ADDITION_IN_COLLAPSE}"
            )
        elif torch.all(potential >= 0):
            potential -= Const.ADDITION_IN_COLLAPSE
            print(
                f"{'*' * 10} Mask expanded entirely, reduced {Const.ADDITION_IN_COLLAPSE}"
            )
        self.potential = potential.half()
        self.latent_mask = self.get_threshed_mask(self.potential)
        return self.get_curr_masks(return_only=return_only)

    def get_curr_masks(self, return_only=None):
        """Return ``(potential, latent_mask)``, or just one of them.

        Args:
            return_only: None for the pair, ``N.POTENTIAL`` or
                ``N.LATENT_MASK`` for a single tensor.

        Raises:
            ValueError: for any other ``return_only`` value.
        """
        if return_only is None:
            return self.potential, self.latent_mask
        if return_only == N.POTENTIAL:
            return self.potential
        if return_only == N.LATENT_MASK:
            return self.latent_mask
        raise ValueError(f"return_only should be in ('{N.POTENTIAL}', '{N.LATENT_MASK}')")

    def make_cached_masks_clones(self, name):
        """Store detached copies of the current masks under ``name``."""
        self.cached_masks_clones[name] = {
            N.POTENTIAL: self.potential.detach().clone(),
            N.LATENT_MASK: self.latent_mask.detach().clone(),
        }

    def set_masks_from_cached_masks_clones(self, name):
        """Restore the masks stored under ``name``.

        Fresh clones are assigned so that in-place edits of the restored masks
        cannot corrupt the snapshot, which may be restored again later.

        Raises:
            KeyError: if no snapshot called ``name`` exists.
        """
        cached = self.cached_masks_clones[name]
        self.potential = cached[N.POTENTIAL].clone()
        self.latent_mask = cached[N.LATENT_MASK].clone()

    def evolve_mask(
        self, step_i, decoder, latent_pred_z0, source_latents, return_only=None
    ):
        """Advance the mask one step using Alpha-CLIP gradients.

        The gradient magnitudes are blurred and standardised. Only positive
        (above-mean) values are kept, and they are applied inside the
        boundary ring of the current mask.

        Returns:
            The same as ``get_curr_masks(return_only)`` after the update.
        """
        potential, latent_mask = self.get_curr_masks()
        surround_ring = self.get_ring(latent_mask)
        grads = self.calc_grads(
            latent_pred_z0=latent_pred_z0,
            source_latents=source_latents,
            potential=potential,
            step_i=step_i,
            decoder=decoder,
        )
        grads = transforms.GaussianBlur(
            Const.GAUSS_K_GRADS, sigma=Const.GAUSS_SIGMA_GRADS
        )(torch.abs(grads))
        grads = (grads - grads.mean()) / grads.std().clamp_min(1e-6)
        grads = grads.clamp_min(0.0)
        self.set_cur_masks(
            step_i=step_i, grads_to_update=grads, surround_ring=surround_ring
        )
        return self.get_curr_masks(return_only=return_only)

    def calc_grads(self, latent_pred_z0, source_latents, potential, step_i, decoder):
        """Gradient of the Alpha-CLIP loss w.r.t. the thresholded latent mask.

        The latents are blended as ``pred * mask + source * (1 - mask)`` and
        decoded. The decoded image is then scored with Alpha-CLIP, using the
        upsampled and dilated mask as the alpha channel. The loss is also
        recorded in ``self.closs_hist[step_i - 1]``.

        Returns:
            Detached float64 gradient with the latent mask's shape.
        """
        with torch.enable_grad():
            latent_mask = self.get_threshed_mask(potential).detach().requires_grad_()
            blend = latent_pred_z0 * latent_mask + (source_latents * (1 - latent_mask))
            # 0.18215: presumably the Stable Diffusion VAE latent scale factor.
            decoded = decoder(1 / 0.18215 * blend).sample.to(torch.float32)
            alpha_mask = transforms.Resize(
                self.decoded_size,
                interpolation=transforms.InterpolationMode.NEAREST,
            )(latent_mask)
            alpha_mask = (alpha_mask > 0.5).half()
            alpha_mask = get_surround(
                alpha_mask,
                Const.ALPHA_MASK_DILATION_ON_512 * (Const.HW / 512.0),
                self.device,
            )
            alpha_loss = self.alpha_clip_loss(
                decoded,
                alpha_mask,
                self.text_features,
                self.image_augmentations,
                augs_with_orig=True,
            )
            # The mask used for the loss is the previous step's mask.
            self.closs_hist[step_i - 1] = alpha_loss.detach()
            grads_latent = torch.autograd.grad(alpha_loss, latent_mask)[0].to(
                torch.float64
            )
        return grads_latent.detach()

    def alpha_clip_loss(
        self,
        image,
        mask,
        text_features,
        image_augmentations,
        augs_with_orig=True,
        return_as_similarity=False,
    ):
        """Alpha-CLIP loss of a masked image against text features.

        ``image`` is expected in [-1, 1] and is mapped to [0, 1] here.
        ``mask`` must be in [0, 1]. With augmentations, the augmented batch
        is normalised. Without them, the image is resized (bicubic) and the
        mask average-pooled to ``self.ac_size``.

        Returns:
            ``1 - image_features @ text_features.T`` (or the similarity itself
            if ``return_as_similarity``), averaged over dim 0. The image
            features are L2-normalised.
        """
        assert mask.min() >= 0 and mask.max() <= 1
        clip_mean = (0.48145466, 0.4578275, 0.40821073)
        clip_std = (0.26862954, 0.26130258, 0.27577711)
        image_normalize = transforms.Normalize(clip_mean, clip_std)
        mask_normalize = transforms.Normalize(0.5, 0.26)

        image = image.add(1).div(2)
        if image.ndim == 3:
            image = image.unsqueeze(0)
        alpha = mask.unsqueeze(dim=0) if mask.ndim == 3 else mask

        if image_augmentations is not None:
            image, alpha = image_augmentations(image, alpha, with_orig=augs_with_orig)
            image = image_normalize(image).half()
            alpha = mask_normalize(alpha).half()
        else:
            image_transform = transforms.Compose(
                [
                    transforms.Resize(
                        self.ac_size,
                        interpolation=transforms.InterpolationMode.BICUBIC,
                    ),
                    image_normalize,
                ]
            )
            mask_transform = transforms.Compose(
                [nn.AdaptiveAvgPool2d(self.ac_size), mask_normalize]
            )
            image = image_transform(image).half()
            alpha = mask_transform(alpha).half()

        image_features = self.ac_model.visual(image, alpha)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        similarity = image_features @ text_features.T
        alpha_loss = similarity if return_as_similarity else 1 - similarity
        # NOTE(review): the original indentation was lost. The mean is assumed
        # to apply to both branches — confirm.
        return alpha_loss.mean(dim=0)

    def get_text_features(self, prompt):
        """Encode a list/tuple of prompts into L2-normalised text features."""
        assert isinstance(prompt, (list, tuple))
        tokens = alpha_clip.tokenize(prompt).to(self.device)
        text_features = self.ac_model.encode_text(tokens)
        return text_features / text_features.norm(dim=-1, keepdim=True)

    def get_bias(self, step_i):
        """Additive dilation bias for ``step_i``.

        The bias starts at ``BIAS_DILATION_VAL * BIAS_DILATION_DEC_FACTOR**step_i``.
        It is shrunk by 10% repeatedly while it would make the base potential
        positive everywhere, stopping once it is no longer above 1e-8.
        """
        bias = Const.BIAS_DILATION_VAL * (Const.BIAS_DILATION_DEC_FACTOR**step_i)
        while torch.all(self.base_potential + bias > 0) and bias > 1e-8:
            bias *= 0.9
        return bias

    def _fill_holes(self, mask):
        """Fill enclosed holes of a binary tensor (squeezed, then returned as (1, 1, ...) half)."""
        filled = scipy.ndimage.binary_fill_holes(
            mask.cpu().numpy().squeeze().astype(np.uint8)
        )
        return torch.tensor(filled).to(self.device).unsqueeze(0).unsqueeze(0).half()

    def get_threshed_mask(self, potential):
        """Turn a potential into a clean binary half mask.

        The pipeline is:
        1. Threshold at ``thresh_val``.
        2. Fill holes.
        3. Bridge separate blobs to the largest one.
        4. Apply morphological closing.
        5. Fill holes again.
        6. Blur and re-threshold at ``Const.THRESH_POST_GAUSS`` to smooth
           the outline.
        """
        mask = self._fill_holes((potential > self.thresh_val).half())
        mask = self.close_gaps_with_connection(
            mask, thickness=Const.CLOSE_GAPS_WITH_CONNECTION_THICKNESS
        )
        mask = kornia.morphology.closing(
            mask, torch.ones(Const.CLOSING_K, Const.CLOSING_K).to(self.device)
        )
        mask = self._fill_holes(mask)
        mask = transforms.GaussianBlur(
            Const.GAUSS_K_THRESHED, sigma=Const.GAUSS_SIGMA_THRESHED
        )(mask)
        return (mask > Const.THRESH_POST_GAUSS).half()

    def close_gaps_with_connection(self, threshed_mask, thickness):
        """Connect every sufficiently large blob of a binary mask to the largest one.

        Contours whose area is at most 0.1% of the image are dropped, which
        also cleans small specks. Each remaining contour is bridged to the
        largest one by a polygon made of its simplified convex hull plus the
        nearest points on the largest contour. The polygon is drawn both as
        an outline of width ``thickness`` and filled.

        Args:
            threshed_mask: binary tensor, squeezed to 2-D for OpenCV (assumes
                shape (1, 1, H, W) — TODO confirm).
            thickness: outline width in pixels of each connecting polygon.

        Returns:
            ``threshed_mask`` unchanged if it has exactly one contour.
            Otherwise a new (1, 1, H, W) half tensor on ``self.device``.
        """
        mask_np = threshed_mask.cpu().numpy().squeeze().astype(np.uint8)
        # [-2] is the contour list under both OpenCV 3 and OpenCV 4 return conventions.
        contours = cv2.findContours(mask_np, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)[-2]
        if len(contours) == 1:
            return threshed_mask

        min_area = mask_np.shape[-1] * mask_np.shape[-2] * 0.001
        by_area = sorted(
            ((cv2.contourArea(cnt), cnt) for cnt in contours),
            key=lambda area_cnt: area_cnt[0],
            reverse=True,
        )
        contours = [cnt for area, cnt in by_area if area > min_area]

        connected = np.zeros_like(mask_np)
        if contours:
            main = contours[0]
            main_pts = main.reshape(-1, 2)
            cv2.drawContours(connected, contours, 0, 255, -1)
            for i in range(1, len(contours)):
                cv2.drawContours(connected, contours, i, 255, -1)
                hull = cv2.convexHull(contours[i])
                hull = cv2.approxPolyDP(hull, 0.1 * cv2.arcLength(hull, True), True)
                # Polygon = hull vertices, then, for each vertex, every point
                # of the main contour at minimal distance from it.
                bridge = [hull]
                for hull_pt in hull:
                    dists = np.linalg.norm(main_pts - hull_pt, axis=1)
                    bridge.append(main[np.flatnonzero(dists == dists.min())])
                connect = np.concatenate(bridge, axis=0)
                cv2.drawContours(connected, [connect], -1, color=255, thickness=thickness)
                cv2.drawContours(connected, [connect], -1, color=255, thickness=-1)

        return (
            (torch.tensor(connected).to(self.device) > 125)
            .unsqueeze(0)
            .unsqueeze(0)
            .half()
        )

    def get_plain_dilated_latent_mask(
        self,
        last_step_latent_mask,
        step_i,
        total_steps,
        max_area_ratio_for_dilation=None,
        rerun_dyn_start_step_i=None,
    ):
        """Dilate ``last_step_latent_mask`` by a width scheduled over the steps.

        Masks larger than ``max_area_ratio_for_dilation`` of the grid are
        returned unchanged. Otherwise the initial width ``first_k`` is the
        largest value (starting at half the latent size, never below 0) whose
        dilation covers at most 75% of the latent grid. The schedule is
        chosen as follows:

        - With ``rerun_dyn_start_step_i``: a linear ramp from ``first_k`` to
          0, then zeros.
        - Otherwise: ``first_k / max(1, i / 3)``.

        In both cases the last 10 steps use width 0.

        Returns:
            Half tensor mask for ``step_i``.
        """
        if max_area_ratio_for_dilation is None:
            max_area_ratio_for_dilation = Const.MAX_AREA_RATIO_FOR_DILATION
        if (
            last_step_latent_mask.sum()
            > max_area_ratio_for_dilation * last_step_latent_mask.nelement()
        ):
            return last_step_latent_mask

        latent_area = self.latent_size[0] * self.latent_size[1]
        first_k = self.latent_size[-1] // 2
        # Stop at 0: a negative width would erode the mask instead of dilating it.
        while (
            first_k > 0
            and get_surround(last_step_latent_mask, first_k, self.device).sum()
            > 0.75 * latent_area
        ):
            first_k -= 1

        if rerun_dyn_start_step_i:
            ramp = np.linspace(
                first_k, 0, rerun_dyn_start_step_i + 2 - Const.RERUN_STOP_DILATION
            ).round()
            plain_dilation_ws = np.pad(ramp, (0, total_steps - len(ramp)))
        else:
            plain_dilation_ws = np.array(
                [first_k / max(1, (i / 3)) for i in range(total_steps)]
            ).round()
        # NOTE(review): the original indentation was lost. This is assumed to
        # apply to both schedules (it is a no-op for a typical rerun ramp) — confirm.
        plain_dilation_ws[-10:] = 0
        return get_surround(
            last_step_latent_mask, plain_dilation_ws[step_i], self.device
        ).half()

    def get_ring(self, latent_mask):
        """Band of pixels around the boundary of ``latent_mask``.

        The band holds, roughly, mask pixels within ``Const.IN_ON_RING_WIDTH``
        of the boundary plus non-mask pixels within ``Const.OUT_RING_WIDTH``.
        Both are measured with the signed distance field, which is negative
        inside and positive outside.

        Returns:
            uint8 tensor (1 in the band) with the shape of ``latent_mask``,
            on ``self.device``.
        """
        assert (latent_mask.min() >= 0) and (latent_mask.max() <= 1)
        out_ring_width = Const.OUT_RING_WIDTH
        in_on_ring_width = Const.IN_ON_RING_WIDTH
        latent_mask = (latent_mask.cpu().numpy() >= 0.5).astype(np.float16)
        dists = get_dist_field(latent_mask, self.device, as_squeezed_np=True)
        in_ring_width = in_on_ring_width - 1
        # Inner ring: -in_ring_width - 1 < d <= -1.
        in_ring = dists.copy()
        in_ring[in_ring > -1] = 0
        in_ring[in_ring <= -in_ring_width - 1] = 0
        in_ring[in_ring != 0] = 1
        # Boundary layer: mask pixels with d >= -1.
        on_ring = latent_mask.copy()
        on_ring[dists < -1] = 0
        in_on_ring = in_ring.astype(bool) | on_ring.astype(bool)
        # Outer ring: 0 < d <= out_ring_width.
        out_ring = dists.copy()
        out_ring[out_ring <= 0] = 0
        out_ring[out_ring > out_ring_width] = 0
        out_ring[out_ring != 0] = 1
        surround_ring = in_on_ring.astype(np.uint8) | out_ring.astype(np.uint8)
        return torch.tensor(surround_ring).to(self.device)