# tuotuotu / utils.py
# Uploaded by baobao7758520 ("Upload 4 files"), commit 3c5c3c1 (verified).
import torch
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline
import numpy as np
# Module-wide store of loaded inpainting pipelines, keyed by model_id, so a
# given checkpoint is only loaded from disk/hub once per process.
pipe_cache: dict = dict()
def _load_pipeline(model_id: str) -> StableDiffusionInpaintPipeline:
    """Return the inpainting pipeline for *model_id*, loading it on first use.

    Loaded pipelines are stored in the module-level ``pipe_cache`` dict so
    repeated calls with the same ``model_id`` reuse one instance. float16
    weights are used on CUDA; on CPU the pipeline is loaded in float32.
    """
    pipe = pipe_cache.get(model_id)
    if pipe is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        dtype = torch.float16 if device == "cuda" else torch.float32
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            model_id,
            torch_dtype=dtype,
        ).to(device)
        pipe_cache[model_id] = pipe
    return pipe


def outpaint_image(
    image: Image.Image,
    prompt: str,
    directions: list,
    model_id: str,
    expand_px: int = 128,
) -> Image.Image:
    """Extend *image* outward on the selected sides using SD inpainting.

    Args:
        image: Source picture; converted to RGB before compositing.
        prompt: Text prompt guiding the content generated in the new border.
        directions: Container of side labels: "上" (top), "下" (bottom),
            "左" (left), "右" (right). Each listed side grows by ``expand_px``.
        model_id: Checkpoint id/path passed to
            ``StableDiffusionInpaintPipeline.from_pretrained``.
        expand_px: Pixels added on each selected side (default 128).

    Returns:
        The outpainted image of size
        ``(width + left + right, height + top + bottom)``. If no side is
        expanded, an RGB copy of *image* is returned without running the model.

    Raises:
        ValueError: If ``expand_px`` is negative.
    """
    if expand_px < 0:
        raise ValueError(f"expand_px must be non-negative, got {expand_px}")

    # Direction labels are Chinese: 上 = top, 下 = bottom, 左 = left, 右 = right.
    top = expand_px if "上" in directions else 0
    bottom = expand_px if "下" in directions else 0
    left = expand_px if "左" in directions else 0
    right = expand_px if "右" in directions else 0

    # Normalise the mode so RGBA / palette / greyscale inputs composite onto
    # the RGB canvas the same way every time.
    source = image.convert("RGB")
    if not (top or bottom or left or right):
        # Nothing to extend: skip loading and running the model entirely.
        return source

    width, height = source.size
    new_width = width + left + right
    new_height = height + top + bottom

    # The SD pipeline rejects sizes not divisible by 8, so generate on a canvas
    # padded at the bottom/right and crop that padding off afterwards.
    gen_width = new_width + (-new_width) % 8
    gen_height = new_height + (-new_height) % 8

    canvas = Image.new("RGB", (gen_width, gen_height), (255, 255, 255))
    canvas.paste(source, (left, top))

    # Inpainting mask convention: white (255) is regenerated, black (0) is
    # kept. Protect exactly the original picture's footprint; everything else
    # (new borders plus any alignment padding) is generated.
    mask = Image.new("L", (gen_width, gen_height), 255)
    mask.paste(0, (left, top, left + width, top + height))

    pipe = _load_pipeline(model_id)
    # Pass the size explicitly: otherwise the pipeline falls back to the UNet's
    # native resolution (e.g. 512x512) and stretches the canvas to fit it.
    # NOTE: memory use grows with canvas size; very large inputs may not fit.
    result = pipe(
        prompt=prompt,
        image=canvas,
        mask_image=mask,
        height=gen_height,
        width=gen_width,
    ).images[0]
    return result.crop((0, 0, new_width, new_height))