DiT360_edit / pa_src /utils.py
asd755's picture
Upload 3 files
70194df verified
import torch
import numpy as np
from PIL import Image
import cv2
from typing import Optional
def shift_tensor(tensor, x):
    """Shift ``tensor`` along dimension 1 by ``x`` positions, zero-filling the gap.

    Args:
        tensor: Tensor with at least two dimensions; the shift is applied
            along dimension 1.
        x: Number of positions to shift. Positive values move content
            towards higher indices, negative values towards lower indices,
            and 0 leaves the content where it is.

    Returns:
        A new tensor with the same shape as ``tensor``. Positions vacated by
        the shift are zero. The result never shares storage with the input.
    """
    if x == 0:
        # Copy instead of returning the input itself, so the result is a
        # fresh tensor for every value of x (as in the shifted cases).
        return tensor.clone()

    shifted_tensor = torch.zeros_like(tensor)
    if x > 0:
        shifted_tensor[:, x:] = tensor[:, :-x]
    else:
        shifted_tensor[:, :x] = tensor[:, -x:]
    return shifted_tensor
def create_mask(input_image_path, w=64, h=64):
    """Load an image and turn it into a binary integer mask tensor.

    The image is resized to ``(w, h)`` with nearest-neighbour sampling and
    converted to 8-bit grayscale. Pixels that are exactly 255 become 1 and
    every other pixel becomes 0.

    Args:
        input_image_path: Path of the image file, passed to ``Image.open``.
        w: Target width in pixels.
        h: Target height in pixels.

    Returns:
        An integer tensor (converted with ``.int()``) of 0/1 values with one
        entry per pixel of the resized image.
    """
    # The context manager closes the underlying file deterministically
    # instead of leaving it to garbage collection.
    with Image.open(input_image_path) as source:
        img = source.resize((w, h), Image.Resampling.NEAREST).convert("L")
    img_array = np.array(img)
    mask = np.where(img_array == 255, 1, 0)
    return torch.tensor(mask).int()
def save_array_as_png(array, path):
    """Write ``array`` to ``path`` as an RGBA image.

    Args:
        array: Pixel data interpreted as RGBA. When its dtype is not
            ``uint8`` the values are multiplied by 255, clipped to
            ``[0, 255]`` and cast to ``uint8`` before saving.
        path: Destination passed to ``Image.save``.
    """
    needs_scaling = array.dtype != np.uint8
    if needs_scaling:
        scaled = array * 255
        array = scaled.clip(0, 255).astype(np.uint8)
    Image.fromarray(array, "RGBA").save(path)
def convert_to_mask_inpainting(image_array, mask_path):
    """Build an inpainting mask from the alpha channel of an RGBA array and save it.

    Pixels whose alpha is zero become 255 in the mask; pixels with any
    non-zero alpha become 0.

    Args:
        image_array: Array of shape ``(height, width, 4)``; channel 3 is
            read as alpha.
        mask_path: Destination path the mask image is saved to.

    Returns:
        The mask as a mode ``"L"`` PIL image.

    Raises:
        ValueError: If ``image_array`` is not three-dimensional with exactly
            four channels in its last axis.
    """
    # Check the rank first so 2-D (or 1-D) input raises ValueError rather
    # than an IndexError from shape[2].
    if image_array.ndim != 3 or image_array.shape[2] != 4:
        # Message reads: "input array must be in RGBA format".
        raise ValueError("输入数组必须是 RGBA 格式")

    alpha_channel = image_array[:, :, 3]
    mask = np.where(alpha_channel != 0, 0, 255).astype(np.uint8)
    mask_image = Image.fromarray(mask, mode="L")
    mask_image.save(mask_path)
    return mask_image
# mask for Subject Customization
def composite_images(background_path: str, mask_path: str) -> Image.Image:
    """Keep the masked region of an image and paint everything else white.

    Args:
        background_path: Path of the image to composite.
        mask_path: Path of the mask image. It is converted to grayscale and
            resized to the background size when the sizes differ; pixels
            with a value above 128 select the background.

    Returns:
        An RGB image showing the background where the mask is above 128 and
        white everywhere else.
    """
    # Context managers close the source files once the converted copies exist.
    with Image.open(background_path) as source:
        background = source.convert("RGBA")
    with Image.open(mask_path) as source:
        mask = source.convert("L")

    if background.size != mask.size:
        mask = mask.resize(background.size)

    # A boolean array becomes a 1-bit image, usable as a composite mask.
    selection = Image.fromarray(np.array(mask) > 128)

    # The background was just converted to RGBA, so the canvas is RGBA too.
    white_canvas = Image.new("RGBA", background.size, (255, 255, 255, 255))
    composite = Image.composite(background, white_canvas, selection)
    return composite.convert("RGB")
def process_mask_array(mask_array: np.ndarray) -> Image.Image:
    """Convert the alpha channel of an array into a 1-bit mask image.

    Positions whose alpha (index 3 of the last axis) is greater than zero
    become 0; all other positions become 255. The grayscale result is then
    converted to mode ``"1"``.
    """
    opacity = mask_array[..., 3]
    # Start fully white, then black out every position that has any alpha.
    pixels = np.full(opacity.shape, 255, dtype=np.uint8)
    pixels[opacity > 0] = 0
    return Image.fromarray(pixels, mode="L").convert("1")
def process_mask(mask: Image.Image) -> Image.Image:
    """Threshold a mask image into a mode ``"1"`` image.

    The mask is converted to grayscale unless it already is mode ``"L"``.
    Pixel values above 128 map to 1 and all others to 0.
    """

    def _threshold(level):
        return 1 if level > 128 else 0

    grayscale = mask if mask.mode == "L" else mask.convert("L")
    return grayscale.point(_threshold, mode="1")
def merge_masks(mask1: Image.Image, mask2: Image.Image) -> Image.Image:
    """Combine two masks with an element-wise logical AND.

    Both masks are read as boolean arrays; the result is set only where
    both inputs are set, and is returned as a mode ``"1"`` image.
    """
    first, second = (np.array(m, dtype=bool) for m in (mask1, mask2))
    return Image.fromarray(first & second).convert("1")
def save_merged_mask(
    mask_array: np.ndarray, mask: Optional[Image.Image], output_path: str
) -> None:
    """Save the mask derived from ``mask_array``, optionally merged with ``mask``.

    Args:
        mask_array: Array handed to ``process_mask_array`` to build the
            base mask.
        mask: Optional second mask. When given, it is passed through
            ``process_mask``, resized with nearest-neighbour sampling if its
            size differs from the base mask, and combined with the base mask
            via ``merge_masks``.
        output_path: Destination the final mask image is saved to.
    """
    result = process_mask_array(mask_array)
    if mask is not None:
        extra = process_mask(mask)
        if extra.size != result.size:
            extra = extra.resize(result.size, Image.NEAREST)
        result = merge_masks(result, extra)
    result.save(output_path)