| import torch |
| from PIL import Image |
| from diffusers import StableDiffusionInpaintPipeline |
| import numpy as np |
|
|
# Inpainting pipelines already loaded in this process, keyed by model_id, so
# each model is loaded (and moved to its device) only once.
pipe_cache = {}


def _get_pipeline(model_id: str) -> StableDiffusionInpaintPipeline:
    """Return the inpainting pipeline for *model_id*, loading it on first use.

    The pipeline is placed on CUDA in float16 when CUDA is available, otherwise
    on the CPU in float32, and is stored in ``pipe_cache`` for later calls.
    """
    pipe = pipe_cache.get(model_id)
    if pipe is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # Half precision saves GPU memory; stay in float32 on CPU, where fp16
        # support is limited.
        dtype = torch.float16 if device == "cuda" else torch.float32
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            model_id,
            torch_dtype=dtype,
        ).to(device)
        pipe_cache[model_id] = pipe
    return pipe


def outpaint_image(
    image: Image.Image,
    prompt: str,
    directions: list,
    model_id: str,
    *,
    expand_px: int = 128,
) -> Image.Image:
    """Extend *image* outward on the requested sides using SD inpainting.

    Args:
        image: Source image; it is pasted onto a larger white canvas.
        prompt: Text prompt guiding what is generated in the new border areas.
        directions: Sides to extend, containing any of the strings "上" (top),
            "下" (bottom), "左" (left) and "右" (right). Other entries are
            ignored.
        model_id: Identifier passed to
            ``StableDiffusionInpaintPipeline.from_pretrained``.
        expand_px: Number of pixels added on each requested side.

    Returns:
        An image of size ``(width + left + right, height + top + bottom)`` in
        which only the added border areas were marked for repainting. If no
        side is extended, an RGB copy of *image* is returned without loading
        or running the model.

    Raises:
        ValueError: If *expand_px* is negative.
    """
    if expand_px < 0:
        raise ValueError(f"expand_px must be non-negative, got {expand_px}")

    expand_top = expand_px if "上" in directions else 0
    expand_bottom = expand_px if "下" in directions else 0
    expand_left = expand_px if "左" in directions else 0
    expand_right = expand_px if "右" in directions else 0

    width, height = image.size
    new_width = width + expand_left + expand_right
    new_height = height + expand_top + expand_bottom

    # White background for the new border areas; the original goes on top.
    canvas = Image.new("RGB", (new_width, new_height), (255, 255, 255))
    canvas.paste(image, (expand_left, expand_top))

    # Nothing to fill in: skip loading and running the model entirely.
    if (new_width, new_height) == (width, height):
        return canvas

    # The inpainting pipeline repaints white mask pixels and preserves black
    # ones, so only the added borders are white; the original area is black.
    mask = Image.new("L", (new_width, new_height), 255)
    mask.paste(
        0,
        (expand_left, expand_top, expand_left + width, expand_top + height),
    )

    # Without explicit height/width the pipeline falls back to the UNet's
    # default resolution (sample_size * vae_scale_factor, e.g. 512x512),
    # distorting non-square canvases. Both must also be multiples of 8, so
    # round down here and scale the result back afterwards.
    gen_width = max(8, new_width - new_width % 8)
    gen_height = max(8, new_height - new_height % 8)
    model_image, model_mask = canvas, mask
    if (gen_width, gen_height) != (new_width, new_height):
        model_image = canvas.resize((gen_width, gen_height), Image.LANCZOS)
        # NEAREST keeps the mask strictly two-valued.
        model_mask = mask.resize((gen_width, gen_height), Image.NEAREST)

    pipe = _get_pipeline(model_id)
    result = pipe(
        prompt=prompt,
        image=model_image,
        mask_image=model_mask,
        height=gen_height,
        width=gen_width,
    ).images[0]

    if result.size != (new_width, new_height):
        result = result.resize((new_width, new_height), Image.LANCZOS)
    return result
|
|