| import torch |
| import numpy as np |
| from PIL import Image |
| import cv2 |
| from typing import Optional |
|
|
|
|
| def shift_tensor(tensor, x): |
| shifted_tensor = torch.zeros_like(tensor) |
|
|
| if x > 0: |
| shifted_tensor[:, x:] = tensor[:, :-x] |
| elif x < 0: |
| shifted_tensor[:, :x] = tensor[:, -x:] |
| else: |
| shifted_tensor = tensor |
|
|
| return shifted_tensor |
|
|
|
|
| def create_mask(input_image_path, w=64, h=64): |
| img = ( |
| Image.open(input_image_path) |
| .resize((w, h), Image.Resampling.NEAREST) |
| .convert("L") |
| ) |
| img_array = np.array(img) |
| mask = np.where(img_array == 255, 1, 0) |
| mask_tensor = torch.tensor(mask).int() |
|
|
| return mask_tensor |
|
|
|
|
| def save_array_as_png(array, path): |
| if array.dtype != np.uint8: |
| array = (array * 255).clip(0, 255).astype(np.uint8) |
| image = Image.fromarray(array, "RGBA") |
| image.save(path) |
|
|
|
|
| def convert_to_mask_inpainting(image_array, mask_path): |
| if image_array.shape[2] != 4: |
| raise ValueError("输入数组必须是 RGBA 格式") |
| mask = np.ones(image_array.shape[:2], dtype=np.uint8) * 255 |
| alpha_channel = image_array[:, :, 3] |
| mask[alpha_channel != 0] = 0 |
| mask_image = Image.fromarray(mask, mode="L") |
| mask_image.save(mask_path) |
|
|
| return mask_image |
|
|
|
|
| |
| def composite_images(background_path: str, mask_path: str) -> Image.Image: |
| background = Image.open(background_path).convert("RGBA") |
| mask = Image.open(mask_path).convert("L") |
|
|
| if background.size != mask.size: |
| mask = mask.resize(background.size) |
|
|
| mask_array = np.array(mask) > 128 |
|
|
| if background.mode == "RGBA": |
| white_canvas = Image.new("RGBA", background.size, (255, 255, 255, 255)) |
| else: |
| white_canvas = Image.new("RGB", background.size, (255, 255, 255)) |
|
|
| composite = Image.composite(background, white_canvas, Image.fromarray(mask_array)) |
|
|
| return composite.convert("RGB") |
|
|
|
|
| def process_mask_array(mask_array: np.ndarray) -> Image.Image: |
| alpha = mask_array[..., 3] |
| gray_array = np.where(alpha > 0, 0, 255).astype(np.uint8) |
| mask_image = Image.fromarray(gray_array, mode="L") |
| return mask_image.convert("1") |
|
|
|
|
| def process_mask(mask: Image.Image) -> Image.Image: |
| if mask.mode != "L": |
| mask = mask.convert("L") |
| return mask.point(lambda x: 1 if x > 128 else 0, mode="1") |
|
|
|
|
| def merge_masks(mask1: Image.Image, mask2: Image.Image) -> Image.Image: |
| arr1 = np.array(mask1, dtype=bool) |
| arr2 = np.array(mask2, dtype=bool) |
| merged = np.logical_and(arr1, arr2) |
| return Image.fromarray(merged).convert("1") |
|
|
|
|
| def save_merged_mask( |
| mask_array: np.ndarray, mask: Optional[Image.Image], output_path: str |
| ) -> None: |
| mask1 = process_mask_array(mask_array) |
|
|
| if mask is not None: |
| mask2 = process_mask(mask) |
| if mask1.size != mask2.size: |
| mask2 = mask2.resize(mask1.size, Image.NEAREST) |
| merged = merge_masks(mask1, mask2) |
| else: |
| merged = mask1 |
|
|
| merged.save(output_path) |
|
|