AItool commited on
Commit
8ec9d64
·
verified ·
1 Parent(s): 8448b12

Delete croper.py

Browse files
Files changed (1) hide show
  1. croper.py +0 -108
croper.py DELETED
@@ -1,108 +0,0 @@
1
- import PIL
2
- import numpy as np
3
-
4
- from PIL import Image
5
-
6
- class Croper:
7
- def __init__(
8
- self,
9
- input_image: PIL.Image,
10
- target_mask: np.ndarray,
11
- mask_size: int = 256,
12
- mask_expansion: int = 20,
13
- ):
14
- self.input_image = input_image
15
- self.target_mask = target_mask
16
- self.mask_size = mask_size
17
- self.mask_expansion = mask_expansion
18
-
19
- def corp_mask_image(self):
20
- target_mask = self.target_mask
21
- input_image = self.input_image
22
- mask_expansion = self.mask_expansion
23
- original_width, original_height = input_image.size
24
- mask_indices = np.where(target_mask)
25
- start_y = np.min(mask_indices[0])
26
- end_y = np.max(mask_indices[0])
27
- start_x = np.min(mask_indices[1])
28
- end_x = np.max(mask_indices[1])
29
- mask_height = end_y - start_y
30
- mask_width = end_x - start_x
31
- # choose the max side length
32
- max_side_length = max(mask_height, mask_width)
33
- # expand the mask area
34
- height_diff = (max_side_length - mask_height) // 2
35
- width_diff = (max_side_length - mask_width) // 2
36
- start_y = start_y - mask_expansion - height_diff
37
- if start_y < 0:
38
- start_y = 0
39
- end_y = end_y + mask_expansion + height_diff
40
- if end_y > original_height:
41
- end_y = original_height
42
- start_x = start_x - mask_expansion - width_diff
43
- if start_x < 0:
44
- start_x = 0
45
- end_x = end_x + mask_expansion + width_diff
46
- if end_x > original_width:
47
- end_x = original_width
48
- expanded_height = end_y - start_y
49
- expanded_width = end_x - start_x
50
- expanded_max_side_length = max(expanded_height, expanded_width)
51
- # calculate the crop area
52
- crop_mask = target_mask[start_y:end_y, start_x:end_x]
53
- crop_mask_start_y = (expanded_max_side_length - expanded_height) // 2
54
- crop_mask_end_y = crop_mask_start_y + expanded_height
55
- crop_mask_start_x = (expanded_max_side_length - expanded_width) // 2
56
- crop_mask_end_x = crop_mask_start_x + expanded_width
57
- # create a square mask
58
- square_mask = np.zeros((expanded_max_side_length, expanded_max_side_length), dtype=target_mask.dtype)
59
- square_mask[crop_mask_start_y:crop_mask_end_y, crop_mask_start_x:crop_mask_end_x] = crop_mask
60
- square_mask_image = Image.fromarray((square_mask * 255).astype(np.uint8))
61
-
62
- crop_image = input_image.crop((start_x, start_y, end_x, end_y))
63
- square_image = Image.new("RGB", (expanded_max_side_length, expanded_max_side_length))
64
- square_image.paste(crop_image, (crop_mask_start_x, crop_mask_start_y))
65
-
66
- self.origin_start_x = start_x
67
- self.origin_start_y = start_y
68
- self.origin_end_x = end_x
69
- self.origin_end_y = end_y
70
-
71
- self.square_start_x = crop_mask_start_x
72
- self.square_start_y = crop_mask_start_y
73
- self.square_end_x = crop_mask_end_x
74
- self.square_end_y = crop_mask_end_y
75
-
76
- self.square_length = expanded_max_side_length
77
- self.square_mask_image = square_mask_image
78
- self.square_image = square_image
79
- self.corp_mask = crop_mask
80
-
81
- mask_size = self.mask_size
82
- self.resized_square_mask_image = square_mask_image.resize((mask_size, mask_size))
83
- self.resized_square_image = square_image.resize((mask_size, mask_size))
84
-
85
- return self.resized_square_mask_image
86
-
87
- def restore_result(self, generated_image):
88
- square_length = self.square_length
89
- generated_image = generated_image.resize((square_length, square_length))
90
- square_mask_image = self.square_mask_image
91
- cropped_generated_image = generated_image.crop((self.square_start_x, self.square_start_y, self.square_end_x, self.square_end_y))
92
- cropped_square_mask_image = square_mask_image.crop((self.square_start_x, self.square_start_y, self.square_end_x, self.square_end_y))
93
-
94
- restored_image = self.input_image.copy()
95
- restored_image.paste(cropped_generated_image, (self.origin_start_x, self.origin_start_y), cropped_square_mask_image)
96
-
97
- return restored_image
98
-
99
- def restore_result_v2(self, generated_image):
100
- square_length = self.square_length
101
- generated_image = generated_image.resize((square_length, square_length))
102
- cropped_generated_image = generated_image.crop((self.square_start_x, self.square_start_y, self.square_end_x, self.square_end_y))
103
-
104
- restored_image = self.input_image.copy()
105
- restored_image.paste(cropped_generated_image, (self.origin_start_x, self.origin_start_y))
106
-
107
- return restored_image
108
-