File size: 5,640 Bytes
194b4ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# Author: OedoSoldier [大江户战士]
# https://space.bilibili.com/55123

from PIL import Image, ImageFilter
import numpy as np


class CropUtils(object):
    """
    Utility functions for cropping an image around a mask and for pasting a
    processed crop back into the original image.

    `crop_img()` takes an image and a mask, thresholds the mask, and crops both
    to the bounding box of the mask pixels above the threshold. If the crop is
    not square, both crops are centred on a zero-filled square RGBA canvas.
    It returns the cropped image, the cropped mask, and a 6-tuple made of the
    bounding box followed by the un-padded crop size.

    `restore_by_file()` takes the raw image, a processed square crop, a
    reference image and a blur mask. The non-zero area of the reference
    image's last band gives the region of the crop that holds real content;
    that region is pasted into the raw image at the position given by `info`,
    and the Gaussian-blurred blur mask is used as the alpha of the pasted
    result so that its edges blend into the raw image.
    """

    @staticmethod
    def _place_on_canvas(patch, canvas_size, left, top):
        """
        Return a zero-filled RGBA canvas with `patch` copied in at (left, top).

        Args:
            patch: An RGBA PIL.Image object to copy onto the canvas.
            canvas_size: The canvas size as (width, height).
            left: X coordinate of the patch's upper-left corner on the canvas.
            top: Y coordinate of the patch's upper-left corner on the canvas.

        Returns:
            The canvas as an RGBA PIL.Image object.
        """
        patch_w, patch_h = patch.size
        # NumPy arrays are indexed (row, column), i.e. (y, x).
        canvas = np.zeros((canvas_size[1], canvas_size[0], 4), dtype=np.uint8)
        canvas[top:top + patch_h, left:left + patch_w, :] = np.asarray(patch)
        # A (H, W, 4) uint8 array is read as RGBA by Image.fromarray.
        return Image.fromarray(canvas)

    def crop_img(self, img, mask, threshold=50):
        """
        Crop the given image using the given mask.

        Args:
            img: The image to be cropped, as a PIL.Image object.
            mask: The mask to be used for cropping, as a PIL.Image object. It is resized to
                  the image size when the two sizes differ.
            threshold: The threshold to use for converting the mask to binary. Pixels in the mask
                       with a value greater than the threshold will be considered as part of the
                       mask, and will be included in the cropped image. Pixels with a value less
                       than or equal to the threshold will be ignored. (default: 50)

        Returns:
            A tuple (cropped image, cropped mask, info), where info is the bounding box followed
            by the un-padded crop size: (left, upper, right, lower, width, height). When the crop
            is not square, the image and mask are returned as square RGBA images padded with
            zeros. If the mask is empty, the function returns (img, None, None).
        """

        # Compare the full (width, height): checking the width alone would leave
        # a mask of equal width but different height in the wrong coordinates.
        if mask.size != img.size:
            mask = mask.resize(img.size)

        bbox = mask.convert('L').point(
            lambda x: 255 if x > threshold else 0,
            mode='1').getbbox()

        # getbbox() returns None when no mask pixel exceeds the threshold.
        if not bbox:
            return img, None, None

        img, mask = img.crop(bbox), mask.crop(bbox)
        size = img.size
        width, height = size

        if width != height:
            bigside = max(width, height)
            square = (bigside, bigside)
            # Centre the crop on the square canvas.
            offset_x = round((bigside - width) / 2)
            offset_y = round((bigside - height) / 2)

            # The canvas has 4 channels, so convert first: assigning e.g. an
            # RGB image or an L mask directly would fail to broadcast.
            img = self._place_on_canvas(
                img.convert('RGBA'), square, offset_x, offset_y)
            mask = self._place_on_canvas(
                mask.convert('RGBA'), square, offset_x, offset_y)

        return img, mask, bbox + size

    def restore_by_file(
            self,
            raw,
            img,
            ref_img,
            blur_mask,
            info,
            mask_blur=0.5):
        """
        Restore the given cropped image to the given raw image.

        Args:
            raw: The raw image, as a PIL.Image object.
            img: The cropped image, as a PIL.Image object. It is resized to the size of ref_img.
            ref_img: The reference image, as a PIL.Image object. The bounding box of the non-zero
                     pixels of its last band (the alpha channel for an RGBA image) selects the
                     region of img and blur_mask that is pasted back.
            blur_mask: The blur mask, as a PIL.Image object. It is resized to the size of ref_img;
                       after blurring, its last band becomes the alpha channel of the pasted
                       result.
            info: A sequence whose first two entries are the x and y coordinates in raw of the
                  upper-left corner of the pasted region. The info tuple returned by crop_img()
                  has this layout; any further entries are ignored.
            mask_blur: The radius passed to ImageFilter.GaussianBlur. Higher values result in a
                       stronger blur. (default: 0.5)

        Returns:
            The restored image, as an RGBA PIL.Image object of the same size as raw. If the last
            band of ref_img is entirely zero there is nothing to paste, and raw converted to RGBA
            is returned.
        """

        raw_size = raw.size
        ref_size = ref_img.size

        upper_left_x = info[0]
        upper_left_y = info[1]

        img = img.resize(ref_size).convert('RGBA')
        blur_mask = blur_mask.resize(ref_size).convert('RGBA')
        raw = raw.convert('RGBA')

        # Region of the reference image whose last band is non-zero.
        bbox = ref_img.split()[-1].convert('L').point(
            lambda x: 255 if x > 0 else 0, mode='1').getbbox()
        if bbox is None:
            # No non-zero pixel in the reference band: nothing to paste.
            return raw

        img = img.crop(bbox)
        blur_mask = blur_mask.crop(bbox)

        # Full-size blend mask, blurred to soften the border of the patch.
        blur_img = self._place_on_canvas(
            blur_mask, raw_size, upper_left_x, upper_left_y)
        blur_img = blur_img.filter(ImageFilter.GaussianBlur(mask_blur))

        # Full-size layer holding only the patch, composited over the raw image.
        new_img = self._place_on_canvas(
            img, raw_size, upper_left_x, upper_left_y)
        new_img = Image.alpha_composite(raw, new_img)

        # Use the blurred mask as alpha, then composite once more so the raw
        # image shows through wherever the mask is transparent.
        new_img.putalpha(blur_img.split()[-1].convert('L'))
        new_img = Image.alpha_composite(raw, new_img)

        return new_img