| | import argparse |
| | import os |
| |
|
| | import torch |
| | from PIL import Image, ImageFilter |
| | from transformers import CLIPTextModel |
| |
|
| | from diffusers import DPMSolverMultistepScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel |
| |
|
| |
|
# Command-line interface for inpainting a single image with a fine-tuned
# Stable Diffusion 2 inpainting model.
parser = argparse.ArgumentParser(description="Inference")
parser.add_argument(
    "--model_path",
    type=str,
    required=True,
    help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
    "--validation_image",
    type=str,
    required=True,
    help="Path to the validation image file to inpaint.",
)
parser.add_argument(
    "--validation_mask",
    type=str,
    required=True,
    help="Path to the validation mask image file (same size as the validation image).",
)
parser.add_argument(
    "--output_dir",
    type=str,
    default="./test-infer/",
    help="The output directory where predictions are saved",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible inference.")

# Parsed at import time; the script's __main__ section reads this module-level `args`.
args = parser.parse_args()
| |
|
# Base inpainting model; the fine-tuned UNet and text encoder from --model_path replace its own.
BASE_MODEL_ID = "stabilityai/stable-diffusion-2-inpainting"
PROMPT = "a photo of sks"
NUM_IMAGES = 16
NUM_INFERENCE_STEPS = 25
GUIDANCE_SCALE = 5


def _load_pipeline(model_path, device):
    """Build the inpainting pipeline with the fine-tuned UNet and text encoder swapped in.

    Args:
        model_path: Local path or hub id containing ``unet`` and ``text_encoder`` subfolders.
        device: Torch device string the pipeline is moved to.

    Returns:
        The configured ``StableDiffusionInpaintPipeline`` on ``device``.
    """
    pipe = StableDiffusionInpaintPipeline.from_pretrained(BASE_MODEL_ID, torch_dtype=torch.float32, revision=None)
    pipe.unet = UNet2DConditionModel.from_pretrained(model_path, subfolder="unet", revision=None)
    pipe.text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder", revision=None)
    # Replace the base scheduler with DPM-Solver multistep, reusing the base scheduler's config.
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    return pipe.to(device)


def _feather_mask(mask_image):
    """Return a blend mask: the white region grown by one pixel, then its edge box-blurred.

    The result is used as the ``Image.composite`` mask, where white selects the generated
    pixels and black keeps the source pixels. Growing the white area lets generated content
    cover a thin border around the hole, and the blur turns the hard edge into a short ramp.
    """
    # A 3x3 max filter dilates bright (white) pixels, i.e. erodes the black/keep region.
    grown = mask_image.filter(ImageFilter.MaxFilter(3))
    return grown.filter(ImageFilter.BoxBlur(1))


if __name__ == "__main__":
    os.makedirs(args.output_dir, exist_ok=True)
    # Prefer CUDA; fall back to CPU (slow) so the script still runs without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    pipe = _load_pipeline(args.model_path, device)

    generator = None
    if args.seed is not None:
        generator = torch.Generator(device=device).manual_seed(args.seed)

    # Normalize modes up front: the pipeline's output images are RGB and Image.composite
    # requires both images to share a mode, and only accepts '1', 'L' or 'RGBA' masks.
    with Image.open(args.validation_image) as src:
        image = src.convert("RGB")
    with Image.open(args.validation_mask) as src:
        mask_image = src.convert("L")

    results = pipe(
        [PROMPT] * NUM_IMAGES,
        image=image,
        mask_image=mask_image,
        num_inference_steps=NUM_INFERENCE_STEPS,
        guidance_scale=GUIDANCE_SCALE,
        generator=generator,
    ).images

    blend_mask = _feather_mask(mask_image)

    for idx, result in enumerate(results):
        # No height/width is passed to the pipeline, so its output size may differ from the
        # source; composite() needs matching sizes, so resize back when they differ.
        if result.size != image.size:
            result = result.resize(image.size)
        composited = Image.composite(result, image, blend_mask)
        composited.save(os.path.join(args.output_dir, f"{idx}.png"))

    del pipe
    torch.cuda.empty_cache()
| |
|