| | from __future__ import annotations |
| |
|
| | import inspect |
| | from abc import ABC, abstractmethod |
| | from typing import Any, Callable, Iterable, List, Mapping, Optional |
| |
|
| | from diffusers.utils import logging |
| | from PIL import Image |
| |
|
| | from asdff.utils import ( |
| | ADOutput, |
| | bbox_padding, |
| | composite, |
| | mask_dilate, |
| | mask_gaussian_blur, |
| | ) |
| | from asdff.yolo import yolo_detector |
| |
|
# Module logger, obtained from the diffusers logging utility under the
# "diffusers" name.
logger = logging.get_logger("diffusers")


# Signature of a detector: it receives one PIL image and returns either a list
# of PIL images or None.  `AdPipelineBase.__call__` treats each returned image
# as a mask and a None return as "nothing detected".
DetectorType = Callable[[Image.Image], Optional[List[Image.Image]]]
| |
|
| |
|
def ordinal(n: int) -> str:
    """Return *n* rendered with its English ordinal suffix.

    Examples: ``1 -> "1st"``, ``2 -> "2nd"``, ``3 -> "3rd"``, ``4 -> "4th"``,
    ``11 -> "11th"``, ``21 -> "21st"``.
    """
    text = str(n)

    # 11, 12 and 13 (also 111, 212, ...) are irregular and always take "th",
    # regardless of their last digit.
    if 11 <= n % 100 <= 13:
        return text + "th"

    last_digit = n % 10
    if last_digit == 1:
        suffix = "st"
    elif last_digit == 2:
        suffix = "nd"
    elif last_digit == 3:
        suffix = "rd"
    else:
        suffix = "th"
    return text + suffix
| |
|
| |
|
class AdPipelineBase(ABC):
    """Abstract base for pipelines that detect regions in images and inpaint them.

    Subclasses provide two things:

    * ``inpaint_pipeline`` - a callable that is invoked with keyword arguments
      to inpaint one cropped region.
    * ``txt2img_class`` - a class whose ``__call__`` is invoked with this
      instance as ``self`` to generate the initial images when none are given.
    """

    @property
    @abstractmethod
    def inpaint_pipeline(self) -> Callable:
        """Callable used by :meth:`process_inpainting` to inpaint a crop."""
        raise NotImplementedError

    @property
    @abstractmethod
    def txt2img_class(self) -> type:
        """Class whose ``__call__`` is used by :meth:`process_txt2img`."""
        raise NotImplementedError

    def __call__(
        self,
        common: Mapping[str, Any] | None = None,
        txt2img_only: Mapping[str, Any] | None = None,
        inpaint_only: Mapping[str, Any] | None = None,
        images: Image.Image | Iterable[Image.Image] | None = None,
        detectors: DetectorType | Iterable[DetectorType] | None = None,
        mask_dilation: int = 4,
        mask_blur: int = 4,
        mask_padding: int = 32,
    ):
        """Generate (or take) images, detect regions on them and inpaint each region.

        Args:
            common: Keyword arguments passed to both the txt2img call and the
                inpaint call.
            txt2img_only: Extra keyword arguments for the txt2img call only.
                Ignored (with a warning) when ``images`` is given.
            inpaint_only: Extra keyword arguments for the inpaint call only.
                ``"strength"`` defaults to ``0.4`` when absent.
            images: One image or an iterable of images to process.  When
                ``None``, images are produced by :meth:`process_txt2img`.
            detectors: One detector or an iterable of detectors.  When
                ``None``, ``self.default_detector`` is used.  Any iterable
                (including a one-shot iterator or generator) is accepted; it
                is materialised once so every image sees every detector.
            mask_dilation: Passed to ``mask_dilate`` for every mask.
            mask_blur: Passed to ``mask_gaussian_blur`` for every mask.
            mask_padding: Passed to ``bbox_padding`` for every mask bbox.

        Returns:
            ``ADOutput`` with ``init_images`` (a copy of every input image) and
            ``images`` (the processed images).  An image on which no region was
            inpainted is *not* added to ``images``, so ``images`` may be
            shorter than ``init_images``.
        """
        if common is None:
            common = {}
        if txt2img_only is None:
            txt2img_only = {}
        if inpaint_only is None:
            inpaint_only = {}
        if "strength" not in inpaint_only:
            # Build a new dict instead of mutating the caller's mapping.
            inpaint_only = {**inpaint_only, "strength": 0.4}

        detectors = self._resolve_detectors(detectors)

        if images is None:
            txt2img_output = self.process_txt2img(common, txt2img_only)
            txt2img_images = txt2img_output[0]
        else:
            if txt2img_only:
                msg = "Both `images` and `txt2img_only` are specified. if `images` is specified, `txt2img_only` is ignored."
                logger.warning(msg)

            txt2img_images = [images] if not isinstance(images, Iterable) else images

        init_images = []
        final_images = []

        for i, init_image in enumerate(txt2img_images):
            # Keep an untouched copy; `init_image` is rebound below as regions
            # are inpainted one after another.
            init_images.append(init_image.copy())
            final_image = None

            for j, detector in enumerate(detectors):
                masks = detector(init_image)
                if masks is None:
                    logger.info(
                        f"No object detected on {ordinal(i + 1)} image with {ordinal(j + 1)} detector."
                    )
                    continue

                for k, mask in enumerate(masks):
                    mask = mask.convert("L")
                    mask = mask_dilate(mask, mask_dilation)
                    bbox = mask.getbbox()
                    if bbox is None:
                        logger.info(f"No object in {ordinal(k + 1)} mask.")
                        continue
                    mask = mask_gaussian_blur(mask, mask_blur)
                    bbox_padded = bbox_padding(bbox, init_image.size, mask_padding)

                    inpaint_output = self.process_inpainting(
                        common,
                        inpaint_only,
                        init_image,
                        mask,
                        bbox_padded,
                    )
                    inpaint_image = inpaint_output[0][0]

                    final_image = composite(
                        init_image,
                        mask,
                        inpaint_image,
                        bbox_padded,
                    )
                    # Later masks/detectors operate on the already-updated image.
                    init_image = final_image

            if final_image is not None:
                final_images.append(final_image)

        return ADOutput(images=final_images, init_images=init_images)

    @property
    def default_detector(self) -> Callable[..., list[Image.Image] | None]:
        """Detector used when the caller passes ``detectors=None``."""
        return yolo_detector

    def _resolve_detectors(
        self, detectors: DetectorType | Iterable[DetectorType] | None
    ) -> list:
        """Normalise the ``detectors`` argument of ``__call__`` into a list.

        ``None`` becomes ``[self.default_detector]``, a single non-iterable
        detector is wrapped in a list, and any iterable is copied into a list.
        The copy matters because the detectors are iterated once per image: a
        one-shot iterator would otherwise be exhausted after the first image.
        """
        if detectors is None:
            return [self.default_detector]
        if not isinstance(detectors, Iterable):
            return [detectors]
        return list(detectors)

    def _get_txt2img_args(
        self, common: Mapping[str, Any], txt2img_only: Mapping[str, Any]
    ):
        """Merge txt2img keyword arguments; ``output_type`` is forced to ``"pil"``."""
        return {**common, **txt2img_only, "output_type": "pil"}

    def _get_inpaint_args(
        self, common: Mapping[str, Any], inpaint_only: Mapping[str, Any]
    ):
        """Merge inpaint keyword arguments without mutating the inputs.

        When the inpaint callable accepts a ``control_image`` parameter and the
        caller supplied ``image`` but no ``control_image`` in ``common``, the
        ``image`` entry is moved to ``control_image``.  ``num_images_per_prompt``
        is forced to ``1`` and ``output_type`` to ``"pil"``.
        """
        common = dict(common)  # local copy so the caller's mapping is untouched
        sig = inspect.signature(self.inpaint_pipeline)
        if (
            "control_image" in sig.parameters
            and "control_image" not in common
            and "image" in common
        ):
            common["control_image"] = common.pop("image")
        return {
            **common,
            **inpaint_only,
            "num_images_per_prompt": 1,
            "output_type": "pil",
        }

    def process_txt2img(
        self, common: Mapping[str, Any], txt2img_only: Mapping[str, Any]
    ):
        """Call ``txt2img_class.__call__`` on this instance with the merged arguments."""
        txt2img_args = self._get_txt2img_args(common, txt2img_only)
        return self.txt2img_class.__call__(self, **txt2img_args)

    def process_inpainting(
        self,
        common: Mapping[str, Any],
        inpaint_only: Mapping[str, Any],
        init_image: Image.Image,
        mask: Image.Image,
        bbox_padded: tuple[int, int, int, int],
    ):
        """Inpaint the ``bbox_padded`` crop of ``init_image`` under ``mask``.

        The image and mask are cropped to ``bbox_padded`` and passed to
        ``inpaint_pipeline`` as ``image`` / ``mask_image``.  A ``control_image``
        argument, when present, is resized to the crop size first.

        Returns:
            Whatever ``inpaint_pipeline`` returns; ``__call__`` reads the
            inpainted image from ``result[0][0]``.
        """
        crop_image = init_image.crop(bbox_padded)
        crop_mask = mask.crop(bbox_padded)
        inpaint_args = self._get_inpaint_args(common, inpaint_only)
        inpaint_args["image"] = crop_image
        inpaint_args["mask_image"] = crop_mask

        if "control_image" in inpaint_args:
            inpaint_args["control_image"] = inpaint_args["control_image"].resize(
                crop_image.size
            )
        return self.inpaint_pipeline(**inpaint_args)
| |
|