| from typing import List, Union |
|
|
| import torch |
| from diffusers import ( |
| StableDiffusionInpaintPipeline, |
| StableDiffusionXLInpaintPipeline, |
| UNet2DConditionModel, |
| ) |
|
|
| from internals.pipelines.commons import AbstractPipeline |
| from internals.pipelines.high_res import HighRes |
| from internals.pipelines.inpaint_imageprocessor import VaeImageProcessor |
| from internals.util import get_generators |
| from internals.util.cache import clear_cuda_and_gc |
| from internals.util.commons import disable_safety_checker, download_image |
| from internals.util.config import ( |
| get_base_inpaint_model_revision, |
| get_base_inpaint_model_variant, |
| get_hf_cache_dir, |
| get_hf_token, |
| get_inpaint_model_path, |
| get_is_sdxl, |
| get_model_dir, |
| get_num_return_sequences, |
| ) |
|
|
|
|
class InPainter(AbstractPipeline):
    """Wrapper around the Stable Diffusion / SDXL inpainting pipelines.

    Typical lifecycle: optionally call ``init`` with an already loaded base
    pipeline, call ``load`` to build ``self.pipe``, call ``process`` to run
    inpainting, and call ``unload`` to release the pipeline.
    """

    __loaded = False
    # Base pipeline registered through ``init``; stays ``None`` until then.
    __base = None

    def init(self, pipeline: AbstractPipeline):
        """Register the base pipeline whose components ``load`` may reuse."""
        self.__base = pipeline

    def load(self):
        """Build ``self.pipe`` if it has not been built yet.

        Strategy, in order:

        1. If a base pipeline was registered via ``init`` and the configured
           inpaint model path equals the base model directory, the inpaint
           pipeline is assembled from the base pipeline's components
           (see ``create``).
        2. For SDXL, only the UNet is loaded from the inpaint checkpoint and
           combined with the remaining components of the base pipeline.
        3. Otherwise the whole ``StableDiffusionInpaintPipeline`` is loaded
           from the inpaint checkpoint.

        Raises:
            AttributeError: In the SDXL case when ``init`` was never called,
                because the base pipeline's components are required.
        """
        if self.__loaded:
            return

        # ``self.__base`` is name-mangled to ``_InPainter__base``, so the
        # previous ``hasattr(self, "__base")`` test could never succeed and
        # this reuse path was unreachable. Compare against the class-level
        # ``None`` default instead.
        if self.__base is not None and get_inpaint_model_path() == get_model_dir():
            self.create(self.__base)
            self.__loaded = True
            return

        if get_is_sdxl():
            # Fail before downloading/moving the UNet rather than after it.
            if self.__base is None:
                raise AttributeError(
                    "InPainter.init(pipeline) must be called before load() "
                    "when using SDXL"
                )

            unet = UNet2DConditionModel.from_pretrained(
                get_inpaint_model_path(),
                torch_dtype=torch.float16,
                cache_dir=get_hf_cache_dir(),
                token=get_hf_token(),
                subfolder="unet",
                variant=get_base_inpaint_model_variant(),
                revision=get_base_inpaint_model_revision(),
            ).to("cuda")
            # Reuse every base component except the UNet.
            components = {**self.__base.pipe.components, "unet": unet}
            self.pipe = StableDiffusionXLInpaintPipeline(**components).to("cuda")
            # Swap in the project's own image processors for mask and image.
            self.pipe.mask_processor = VaeImageProcessor(
                vae_scale_factor=self.pipe.vae_scale_factor,
                do_normalize=False,
                do_binarize=True,
                do_convert_grayscale=True,
            )
            self.pipe.image_processor = VaeImageProcessor(
                vae_scale_factor=self.pipe.vae_scale_factor
            )
        else:
            self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
                get_inpaint_model_path(),
                torch_dtype=torch.float16,
                cache_dir=get_hf_cache_dir(),
                token=get_hf_token(),
            ).to("cuda")

        disable_safety_checker(self.pipe)

        self.__patch()

        self.__loaded = True

    def create(self, pipeline: AbstractPipeline):
        """Build ``self.pipe`` from the components of ``pipeline.pipe``.

        Args:
            pipeline: A pipeline wrapper whose ``pipe.components`` mapping is
                passed to the SDXL or SD inpaint pipeline constructor,
                depending on ``get_is_sdxl()``.
        """
        if get_is_sdxl():
            pipeline_cls = StableDiffusionXLInpaintPipeline
        else:
            pipeline_cls = StableDiffusionInpaintPipeline

        self.pipe = pipeline_cls(**pipeline.pipe.components).to("cuda")
        disable_safety_checker(self.pipe)

        self.__patch()

    def __patch(self):
        """Enable memory-saving options on ``self.pipe``.

        VAE tiling and slicing are enabled for SDXL only; xformers
        memory-efficient attention is enabled in every case.
        """
        if get_is_sdxl():
            self.pipe.enable_vae_tiling()
            self.pipe.enable_vae_slicing()
        self.pipe.enable_xformers_memory_efficient_attention()

    def unload(self):
        """Drop the pipeline reference and clear CUDA/GC caches."""
        self.__loaded = False
        self.pipe = None
        clear_cuda_and_gc()

    @torch.inference_mode()
    def process(
        self,
        image_url: str,
        mask_image_url: str,
        width: int,
        height: int,
        seed: int,
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]],
        num_inference_steps: int,
        **kwargs,
    ):
        """Run inpainting on a downloaded image/mask pair.

        Args:
            image_url: Location of the image to inpaint.
            mask_image_url: Location of the mask image.
            width: Requested width; both images are resized to it.
            height: Requested height; both images are resized to it.
            seed: Seed handed to ``get_generators``.
            prompt: Prompt(s) forwarded to the pipeline.
            negative_prompt: Negative prompt(s) forwarded to the pipeline.
            num_inference_steps: Number of denoising steps.
            **kwargs: Extra pipeline arguments. They take precedence over the
                named arguments above when keys collide.

        Returns:
            A ``(images, mask_img)`` tuple: the pipeline output images and the
            mask that was actually used (blurred in the SDXL case).
        """
        generator = get_generators(seed, get_num_return_sequences())

        target_size = (width, height)
        input_img = download_image(image_url).resize(target_size)
        mask_img = download_image(mask_image_url).resize(target_size)

        if get_is_sdxl():
            # The images keep the requested size; only the width/height passed
            # to the pipeline are changed to the closest SDXL aspect ratio.
            width, height = HighRes.find_closest_sdxl_aspect_ratio(width, height)
            mask_img = self.pipe.mask_processor.blur(mask_img, blur_factor=33)

            # SDXL always overrides these, even if the caller supplied them.
            kwargs["strength"] = 0.999
            kwargs["padding_mask_crop"] = 1000

        # ``**kwargs`` is spread last so caller/SDXL overrides win over the
        # defaults listed here (e.g. ``strength``).
        call_args = {
            "prompt": prompt,
            "image": input_img,
            "mask_image": mask_img,
            "height": height,
            "width": width,
            "negative_prompt": negative_prompt,
            "num_inference_steps": num_inference_steps,
            "strength": 1.0,
            "generator": generator,
            **kwargs,
        }
        return self.pipe(**call_args).images, mask_img
|
|