|
|
from typing import List, Union |
|
|
|
|
|
import torch |
|
|
from diffusers import StableDiffusionInpaintPipeline |
|
|
from util.commons import disable_safety_checker, download_image |
|
|
|
|
|
|
|
|
class InPainter:
    """Thin wrapper around a diffusers Stable Diffusion inpainting pipeline.

    Call :meth:`load` once to build the pipeline (stored on ``self.pipe``),
    then call :meth:`process` for each inpainting request.
    """

    def load(
        self,
        model_id: str = "runwayml/stable-diffusion-inpainting",
        device: str = "cuda",
    ) -> None:
        """Load the fp16 inpainting pipeline and move it to *device*.

        Args:
            model_id: Model id or local path passed to ``from_pretrained``.
                Defaults to the model this class originally hard-coded.
            device: Torch device the pipeline is moved to.
        """
        self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            # NOTE(review): newer diffusers releases deprecate revision="fp16"
            # in favour of variant="fp16"; confirm against the pinned version.
            revision="fp16",
        ).to(device)
        # Helper from util.commons; presumably disables the pipeline's
        # safety checker, as its name suggests — confirm in util/commons.py.
        disable_safety_checker(self.pipe)

    @torch.inference_mode()
    def process(
        self,
        image_url: str,
        mask_image_url: str,
        width: int,
        height: int,
        seed: int,
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]],
    ):
        """Inpaint the image at *image_url* using the mask at *mask_image_url*.

        Both images are fetched with ``download_image`` and resized to
        ``(width, height)`` before being handed to the pipeline. ``load()``
        must have been called first so that ``self.pipe`` exists.

        Args:
            image_url: Location of the source image.
            mask_image_url: Location of the mask image.
            width: Output (and resize) width in pixels.
            height: Output (and resize) height in pixels.
            seed: Seed for the sampling noise; the same seed and inputs give
                a reproducible result.
            prompt: Text prompt(s) forwarded to the pipeline.
            negative_prompt: Negative prompt(s) forwarded to the pipeline.

        Returns:
            The ``images`` attribute of the pipeline output (a list of PIL
            images under diffusers' default ``output_type``).
        """
        # A dedicated generator on the pipeline's device makes sampling
        # reproducible for `seed` without reseeding torch's global RNG, which
        # would affect any other code sharing this process.
        generator = torch.Generator(device=self.pipe.device).manual_seed(seed)

        # Resize takes (width, height); assumes download_image returns a
        # PIL-style image — TODO confirm in util/commons.py.
        size = (width, height)
        input_img = download_image(image_url).resize(size)
        mask_img = download_image(mask_image_url).resize(size)

        return self.pipe(
            prompt=prompt,
            image=input_img,
            mask_image=mask_img,
            height=height,
            width=width,
            negative_prompt=negative_prompt,
            generator=generator,
        ).images
|
|
|