File size: 3,033 Bytes
70194df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import torch
import numpy as np
from PIL import Image
import cv2
from typing import Optional


def shift_tensor(tensor, x):
    """Shift a tensor along dimension 1 by ``x`` positions, zero-filling the gap.

    Args:
        tensor: Tensor with at least two dimensions; dimension 1 is shifted.
        x: Signed shift amount. A positive value moves content toward higher
            indices, a negative value toward lower indices, and zero leaves
            the content where it is.

    Returns:
        A new tensor with the same shape and dtype as ``tensor``. Positions
        vacated by the shift hold zeros. The result never shares storage
        with the input, including when ``x`` is zero.
    """
    if x == 0:
        # Return a copy so that callers may modify the result in place
        # without touching the input, exactly as for non-zero shifts.
        return tensor.clone()

    shifted_tensor = torch.zeros_like(tensor)
    if x > 0:
        shifted_tensor[:, x:] = tensor[:, :-x]
    else:
        shifted_tensor[:, :x] = tensor[:, -x:]

    return shifted_tensor


def create_mask(input_image_path, w=64, h=64):
    """Load an image as a binary mask tensor.

    The image is resized to ``(w, h)`` with nearest-neighbour sampling and
    converted to 8-bit grayscale. Pixels whose value is exactly 255 map to
    1; every other value maps to 0.

    Args:
        input_image_path: Image source passed to ``Image.open``.
        w: Target width in pixels.
        h: Target height in pixels.

    Returns:
        A 32-bit integer tensor of 0/1 values, one per resized pixel.
    """
    # Close the source file deterministically; resize() returns an
    # independent image, so it stays usable after the file is closed.
    with Image.open(input_image_path) as source:
        img = source.resize((w, h), Image.Resampling.NEAREST).convert("L")

    mask = np.where(np.array(img) == 255, 1, 0)
    return torch.tensor(mask).int()


def save_array_as_png(array, path):
    """Write an array to ``path`` as an RGBA image.

    Arrays whose dtype is not ``uint8`` are multiplied by 255, clipped to
    the range 0..255 and cast to ``uint8`` before saving; ``uint8`` arrays
    are saved unchanged.

    Args:
        array: NumPy array handed to ``Image.fromarray`` with mode "RGBA".
            NOTE(review): presumably shaped (height, width, 4) - confirm
            against callers.
        path: Destination passed to ``Image.save``.
    """
    needs_scaling = array.dtype != np.uint8
    pixels = (array * 255).clip(0, 255).astype(np.uint8) if needs_scaling else array
    Image.fromarray(pixels, "RGBA").save(path)


def convert_to_mask_inpainting(image_array, mask_path):
    """Build a mask from an RGBA array's alpha channel and save it.

    Pixels whose alpha value is 0 become 255 in the mask; all other pixels
    become 0.

    Args:
        image_array: NumPy array of shape (height, width, 4); channel index
            3 is read as alpha.
        mask_path: Destination passed to ``Image.save``.

    Returns:
        The saved mask as a mode "L" PIL image.

    Raises:
        ValueError: If ``image_array`` is not 3-dimensional or its last
            axis does not have exactly 4 channels.
    """
    # Check the rank first so that 2-D (grayscale) input is reported as a
    # ValueError instead of failing with IndexError on shape[2].
    if image_array.ndim != 3 or image_array.shape[2] != 4:
        raise ValueError("输入数组必须是 RGBA 格式")

    alpha_channel = image_array[:, :, 3]
    mask = np.where(alpha_channel != 0, 0, 255).astype(np.uint8)
    mask_image = Image.fromarray(mask, mode="L")
    mask_image.save(mask_path)

    return mask_image


# mask for Subject Customization
def composite_images(background_path: str, mask_path: str) -> Image.Image:
    """Keep the masked region of an image and paint everything else white.

    Args:
        background_path: Path of the image to composite; loaded as RGBA.
        mask_path: Path of the mask image; loaded as 8-bit grayscale and
            resized to the background's size when the two differ.

    Returns:
        An RGB image that takes the background where the mask value is
        above 128 and is white elsewhere. The alpha channel is dropped by
        the final conversion to RGB.
    """
    # Context managers close the files promptly; convert() returns an
    # independent image, so it stays usable after the file is closed.
    with Image.open(background_path) as background_file:
        background = background_file.convert("RGBA")
    with Image.open(mask_path) as mask_file:
        mask = mask_file.convert("L")

    if background.size != mask.size:
        mask = mask.resize(background.size)

    keep_region = Image.fromarray(np.array(mask) > 128)

    # The background was converted to RGBA above, so the canvas is always
    # RGBA as well; no separate RGB case is needed.
    white_canvas = Image.new("RGBA", background.size, (255, 255, 255, 255))

    composite = Image.composite(background, white_canvas, keep_region)

    return composite.convert("RGB")


def process_mask_array(mask_array: np.ndarray) -> Image.Image:
    """Turn the alpha channel of an array into a mode "1" mask image.

    Index 3 of the last axis is read as alpha. Positions with alpha greater
    than 0 become 0 in the mask; every other position becomes 255 before
    the image is converted to mode "1".

    Args:
        mask_array: Array whose last axis has at least 4 entries.

    Returns:
        The resulting mode "1" PIL image.
    """
    opaque = mask_array[..., 3] > 0
    gray = np.full(opaque.shape, 255, dtype=np.uint8)
    gray[opaque] = 0
    return Image.fromarray(gray, mode="L").convert("1")


def process_mask(mask: Image.Image) -> Image.Image:
    """Threshold a mask image into mode "1".

    The image is converted to 8-bit grayscale unless it already is, then
    values above 128 map to 1 and all other values map to 0.

    Args:
        mask: PIL image in any mode that ``convert("L")`` accepts.

    Returns:
        The thresholded mode "1" PIL image.
    """

    def _threshold(value):
        return 1 if value > 128 else 0

    grayscale = mask if mask.mode == "L" else mask.convert("L")
    return grayscale.point(_threshold, mode="1")


def merge_masks(mask1: Image.Image, mask2: Image.Image) -> Image.Image:
    """Combine two masks with a pixel-wise logical AND.

    Both images are read as boolean arrays; a pixel is set in the result
    only where it is set in both inputs.

    Args:
        mask1: First mask image.
        mask2: Second mask image.

    Returns:
        The combined mask as a mode "1" PIL image.
    """
    first, second = (np.array(image, dtype=bool) for image in (mask1, mask2))
    both_set = first & second
    return Image.fromarray(both_set).convert("1")


def save_merged_mask(
    mask_array: np.ndarray, mask: Optional[Image.Image], output_path: str
) -> None:
    """Save the mask derived from ``mask_array``, optionally merged with ``mask``.

    ``mask_array`` is converted with ``process_mask_array``. When ``mask``
    is given, it is converted with ``process_mask``, resized with
    nearest-neighbour sampling if its size differs, and combined with the
    first mask through ``merge_masks``.

    Args:
        mask_array: Array passed to ``process_mask_array``.
        mask: Optional second mask image; ``None`` skips the merge.
        output_path: Destination passed to ``Image.save``.
    """
    result = process_mask_array(mask_array)

    if mask is not None:
        extra = process_mask(mask)
        if result.size != extra.size:
            extra = extra.resize(result.size, Image.NEAREST)
        result = merge_masks(result, extra)

    result.save(output_path)