# NOTE(review): the original paste began with "Spaces:" / "Runtime error" —
# non-code residue from a Hugging Face Spaces error page, kept here as a
# comment so the module remains parseable.
import abc
from typing import Optional

import cv2
import numpy as np
import torch
from loguru import logger

from iopaint.helper import (
    boxes_from_mask,
    pad_img_to_modulo,
    resize_max_size,
    switch_mps_device,
)
from iopaint.schema import HDStrategy, InpaintRequest, SDSampler

from .helper.g_diffuser_bot import expand_image
from .utils import get_scheduler
class InpaintModel:
    """Abstract base class for erase/inpaint models.

    Subclasses implement ``init_model`` and ``forward``; this base class
    provides modulo padding, high-definition strategies (crop / resize),
    cropper support and histogram-matching helpers shared by all models.
    """

    name = "base"
    # Minimum input side length required by the network; inputs are padded
    # up to it by pad_img_to_modulo when set.
    min_size: Optional[int] = None
    # Height/width are padded to a multiple of this value before forward().
    pad_mod = 8
    pad_to_square = False
    is_erase_model = False

    def __init__(self, device, **kwargs):
        """
        Args:
            device: torch device to run on; switch_mps_device swaps MPS for
                another device when this model does not support it.
            **kwargs: forwarded unchanged to ``init_model``.
        """
        device = switch_mps_device(self.name, device)
        self.device = device
        self.init_model(device, **kwargs)

    @abc.abstractmethod
    def init_model(self, device, **kwargs):
        """Load weights and build the underlying network."""
        ...

    @staticmethod
    @abc.abstractmethod
    def is_downloaded() -> bool:
        """Return True when the model weights are already available locally."""
        return False

    @abc.abstractmethod
    def forward(self, image, mask, config: InpaintRequest):
        """Input images and output images have same size
        images: [H, W, C] RGB
        masks: [H, W, 1] where 255 marks the region to inpaint
        return: BGR IMAGE
        """
        ...

    @staticmethod
    def download():
        """Download the model weights; no-op in the base class."""
        ...

    def _pad_forward(self, image, mask, config: InpaintRequest):
        """Pad inputs to the model's required geometry, run forward, and crop
        the result back to the original resolution (BGR output)."""
        origin_height, origin_width = image.shape[:2]
        pad_image = pad_img_to_modulo(
            image, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
        )
        pad_mask = pad_img_to_modulo(
            mask, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
        )

        # logger.info(f"final forward pad size: {pad_image.shape}")

        # NOTE(review): pre-processing is applied to the *unpadded* image/mask;
        # its outputs feed post-processing and the final blend below, while
        # forward() consumes the padded copies. This mirrors upstream behavior.
        image, mask = self.forward_pre_process(image, mask, config)

        result = self.forward(pad_image, pad_mask, config)
        result = result[0:origin_height, 0:origin_width, :]

        result, image, mask = self.forward_post_process(result, image, mask, config)

        if config.sd_keep_unmasked_area:
            # Keep original pixels outside the mask: alpha-blend the model
            # output with the RGB->BGR-flipped input using mask/255 as alpha.
            mask = mask[:, :, np.newaxis]
            result = result * (mask / 255) + image[:, :, ::-1] * (1 - (mask / 255))
        return result

    def forward_pre_process(self, image, mask, config):
        """Hook run before forward; identity in the base class."""
        return image, mask

    def forward_post_process(self, result, image, mask, config):
        """Hook run after forward; identity in the base class."""
        return result, image, mask

    def __call__(self, image, mask, config: InpaintRequest):
        """
        images: [H, W, C] RGB, not normalized
        masks: [H, W]
        return: BGR IMAGE
        """
        inpaint_result = None
        # logger.info(f"hd_strategy: {config.hd_strategy}")
        if config.hd_strategy == HDStrategy.CROP:
            if max(image.shape) > config.hd_strategy_crop_trigger_size:
                logger.info(f"Run crop strategy")
                boxes = boxes_from_mask(mask)
                crop_result = []
                for box in boxes:
                    crop_image, crop_box = self._run_box(image, mask, box, config)
                    crop_result.append((crop_image, crop_box))

                # Start from the RGB->BGR original and paste each inpainted
                # crop back into place.
                inpaint_result = image[:, :, ::-1]
                for crop_image, crop_box in crop_result:
                    x1, y1, x2, y2 = crop_box
                    inpaint_result[y1:y2, x1:x2, :] = crop_image

        elif config.hd_strategy == HDStrategy.RESIZE:
            if max(image.shape) > config.hd_strategy_resize_limit:
                origin_size = image.shape[:2]
                downsize_image = resize_max_size(
                    image, size_limit=config.hd_strategy_resize_limit
                )
                downsize_mask = resize_max_size(
                    mask, size_limit=config.hd_strategy_resize_limit
                )
                logger.info(
                    f"Run resize strategy, origin size: {image.shape} forward size: {downsize_image.shape}"
                )
                inpaint_result = self._pad_forward(
                    downsize_image, downsize_mask, config
                )

                # only paste masked area result
                inpaint_result = cv2.resize(
                    inpaint_result,
                    (origin_size[1], origin_size[0]),
                    interpolation=cv2.INTER_CUBIC,
                )
                original_pixel_indices = mask < 127
                inpaint_result[original_pixel_indices] = image[:, :, ::-1][
                    original_pixel_indices
                ]

        if inpaint_result is None:
            # Strategy not triggered (or HDStrategy.ORIGINAL): full-size pass.
            inpaint_result = self._pad_forward(image, mask, config)

        return inpaint_result

    def _crop_box(self, image, mask, box, config: InpaintRequest):
        """
        Args:
            image: [H, W, C] RGB
            mask: [H, W, 1]
            box: [left,top,right,bottom]

        Returns:
            BGR IMAGE, (l, r, r, b)
        """
        box_h = box[3] - box[1]
        box_w = box[2] - box[0]
        cx = (box[0] + box[2]) // 2
        cy = (box[1] + box[3]) // 2
        img_h, img_w = image.shape[:2]

        # Expand the box by the configured margin on every side.
        w = box_w + config.hd_strategy_crop_margin * 2
        h = box_h + config.hd_strategy_crop_margin * 2

        _l = cx - w // 2
        _r = cx + w // 2
        _t = cy - h // 2
        _b = cy + h // 2

        l = max(_l, 0)
        r = min(_r, img_w)
        t = max(_t, 0)
        b = min(_b, img_h)

        # try to get more context when crop around image edge
        if _l < 0:
            r += abs(_l)
        if _r > img_w:
            l -= _r - img_w
        if _t < 0:
            b += abs(_t)
        if _b > img_h:
            t -= _b - img_h

        # Re-clamp after shifting so the crop stays inside the image.
        l = max(l, 0)
        r = min(r, img_w)
        t = max(t, 0)
        b = min(b, img_h)

        crop_img = image[t:b, l:r, :]
        crop_mask = mask[t:b, l:r]

        # logger.info(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}")

        return crop_img, crop_mask, [l, t, r, b]

    def _calculate_cdf(self, histogram):
        """Return the histogram's cumulative distribution, normalized to [0, 1]."""
        cdf = histogram.cumsum()
        normalized_cdf = cdf / float(cdf.max())
        return normalized_cdf

    def _calculate_lookup(self, source_cdf, reference_cdf):
        """Build a 256-entry LUT mapping source intensities to the reference
        distribution (classic histogram-matching lookup construction)."""
        lookup_table = np.zeros(256)
        lookup_val = 0
        for source_index, source_val in enumerate(source_cdf):
            for reference_index, reference_val in enumerate(reference_cdf):
                if reference_val >= source_val:
                    lookup_val = reference_index
                    break
            lookup_table[source_index] = lookup_val
        return lookup_table

    def _match_histograms(self, source, reference, mask):
        """Match ``source``'s per-channel histograms to ``reference``,
        computing statistics only over non-masked (mask == 0) pixels."""
        transformed_channels = []
        if len(mask.shape) == 3:
            mask = mask[:, :, -1]

        for channel in range(source.shape[-1]):
            source_channel = source[:, :, channel]
            reference_channel = reference[:, :, channel]

            # only calculate histograms for non-masked parts
            source_histogram, _ = np.histogram(source_channel[mask == 0], 256, [0, 256])
            reference_histogram, _ = np.histogram(
                reference_channel[mask == 0], 256, [0, 256]
            )

            source_cdf = self._calculate_cdf(source_histogram)
            reference_cdf = self._calculate_cdf(reference_histogram)
            lookup = self._calculate_lookup(source_cdf, reference_cdf)

            transformed_channels.append(cv2.LUT(source_channel, lookup))

        result = cv2.merge(transformed_channels)
        result = cv2.convertScaleAbs(result)

        return result

    def _apply_cropper(self, image, mask, config: InpaintRequest):
        """Crop image/mask to the user-defined cropper rectangle, clamped to
        the image bounds. Returns (crop_img, crop_mask, (l, t, r, b))."""
        img_h, img_w = image.shape[:2]
        l, t, w, h = (
            config.croper_x,
            config.croper_y,
            config.croper_width,
            config.croper_height,
        )
        r = l + w
        b = t + h

        l = max(l, 0)
        r = min(r, img_w)
        t = max(t, 0)
        b = min(b, img_h)

        crop_img = image[t:b, l:r, :]
        crop_mask = mask[t:b, l:r]
        return crop_img, crop_mask, (l, t, r, b)

    def _run_box(self, image, mask, box, config: InpaintRequest):
        """
        Args:
            image: [H, W, C] RGB
            mask: [H, W, 1]
            box: [left,top,right,bottom]

        Returns:
            BGR IMAGE
        """
        crop_img, crop_mask, [l, t, r, b] = self._crop_box(image, mask, box, config)

        return self._pad_forward(crop_img, crop_mask, config), [l, t, r, b]
class DiffusionInpaintModel(InpaintModel):
    """Base class for Stable-Diffusion style inpaint models, adding cropper,
    extender (outpainting) and sd_scale handling on top of InpaintModel."""

    def __init__(self, device, **kwargs):
        self.model_info = kwargs["model_info"]
        self.model_id_or_path = self.model_info.path
        super().__init__(device, **kwargs)

    def __call__(self, image, mask, config: InpaintRequest):
        """
        images: [H, W, C] RGB, not normalized
        masks: [H, W]
        return: BGR IMAGE
        """
        # boxes = boxes_from_mask(mask)
        if config.use_croper:
            # Inpaint only inside the user-selected crop, then paste the
            # result back onto the RGB->BGR-flipped original.
            crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config)
            crop_image = self._scaled_pad_forward(crop_img, crop_mask, config)
            inpaint_result = image[:, :, ::-1]
            inpaint_result[t:b, l:r, :] = crop_image
        elif config.use_extender:
            inpaint_result = self._do_outpainting(image, config)
        else:
            inpaint_result = self._scaled_pad_forward(image, mask, config)

        return inpaint_result

    def _do_outpainting(self, image, config: InpaintRequest):
        """Expand the image beyond its borders (outpainting) according to the
        extender rectangle in ``config``. Returns a BGR image."""
        # The cropper and the image share one coordinate system; extender_x/y
        # may be negative when the extender reaches past the top/left edge.
        # First crop the overlapping region out of the image.
        image_h, image_w = image.shape[:2]

        cropper_l = config.extender_x
        cropper_t = config.extender_y
        cropper_r = config.extender_x + config.extender_width
        cropper_b = config.extender_y + config.extender_height
        image_l = 0
        image_t = 0
        image_r = image_w
        image_b = image_h

        # Intersection of extender rect and image rect (similar to an IoU).
        l = max(cropper_l, image_l)
        t = max(cropper_t, image_t)
        r = min(cropper_r, image_r)
        b = min(cropper_b, image_b)

        assert (
            0 <= l < r and 0 <= t < b
        ), f"cropper and image not overlap, {l},{t},{r},{b}"

        cropped_image = image[t:b, l:r, :]
        # Padding needed on each side to reach the extender rectangle.
        padding_l = max(0, image_l - cropper_l)
        padding_t = max(0, image_t - cropper_t)
        padding_r = max(0, cropper_r - image_r)
        padding_b = max(0, cropper_b - image_b)

        expanded_image, mask_image = expand_image(
            cropped_image,
            left=padding_l,
            top=padding_t,
            right=padding_r,
            bottom=padding_b,
            softness=config.sd_outpainting_softness,
            space=config.sd_outpainting_space,
        )

        # The final expanded image, BGR
        expanded_cropped_result_image = self._scaled_pad_forward(
            expanded_image, mask_image, config
        )

        # RGB -> BGR
        outpainting_image = cv2.copyMakeBorder(
            image,
            left=padding_l,
            top=padding_t,
            right=padding_r,
            bottom=padding_b,
            borderType=cv2.BORDER_CONSTANT,
            value=0,
        )[:, :, ::-1]

        # Paste the generated region onto outpainting_image; no blending is
        # needed for this step.
        paste_t = 0 if config.extender_y < 0 else config.extender_y
        paste_l = 0 if config.extender_x < 0 else config.extender_x

        outpainting_image[
            paste_t : paste_t + expanded_cropped_result_image.shape[0],
            paste_l : paste_l + expanded_cropped_result_image.shape[1],
            :,
        ] = expanded_cropped_result_image
        return outpainting_image

    def _scaled_pad_forward(self, image, mask, config: InpaintRequest):
        """Optionally downscale by ``config.sd_scale`` before _pad_forward,
        then resize the result back to the original size (BGR)."""
        longer_side_length = int(config.sd_scale * max(image.shape[:2]))
        origin_size = image.shape[:2]
        downsize_image = resize_max_size(image, size_limit=longer_side_length)
        downsize_mask = resize_max_size(mask, size_limit=longer_side_length)
        if config.sd_scale != 1:
            logger.info(
                f"Resize image to do sd inpainting: {image.shape} -> {downsize_image.shape}"
            )
        inpaint_result = self._pad_forward(downsize_image, downsize_mask, config)
        # only paste masked area result
        inpaint_result = cv2.resize(
            inpaint_result,
            (origin_size[1], origin_size[0]),
            interpolation=cv2.INTER_CUBIC,
        )

        # blend result, copy from g_diffuser_bot
        # mask_rgb = 1.0 - np_img_grey_to_rgb(mask / 255.0)
        # inpaint_result = np.clip(
        #     inpaint_result * (1.0 - mask_rgb) + image * mask_rgb, 0.0, 255.0
        # )

        # original_pixel_indices = mask < 127
        # inpaint_result[original_pixel_indices] = image[:, :, ::-1][
        #     original_pixel_indices
        # ]
        return inpaint_result

    def set_scheduler(self, config: InpaintRequest):
        """Swap the diffusers scheduler to match the requested sampler,
        forcing the LCM sampler when an LCM LoRA is enabled and supported."""
        scheduler_config = self.model.scheduler.config
        sd_sampler = config.sd_sampler
        if config.sd_lcm_lora and self.model_info.support_lcm_lora:
            sd_sampler = SDSampler.lcm
            logger.info(f"LCM Lora enabled, use {sd_sampler} sampler")
        scheduler = get_scheduler(sd_sampler, scheduler_config)
        self.model.scheduler = scheduler

    def forward_pre_process(self, image, mask, config):
        """Blur the mask edges before inpainting when sd_mask_blur is set."""
        if config.sd_mask_blur != 0:
            # Kernel size must be odd for GaussianBlur.
            k = 2 * config.sd_mask_blur + 1
            mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis]

        return image, mask

    def forward_post_process(self, result, image, mask, config):
        """Optionally histogram-match the result and re-blur the mask used
        for the final keep-unmasked-area blend."""
        if config.sd_match_histograms:
            result = self._match_histograms(result, image[:, :, ::-1], mask)

        if config.sd_mask_blur != 0:
            k = 2 * config.sd_mask_blur + 1
            mask = cv2.GaussianBlur(mask, (k, k), 0)

        return result, image, mask