GulbaharAI committed on
Commit
dd1d7d4
·
verified ·
1 Parent(s): 952da02

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +105 -0
utils.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ from scipy.ndimage import convolve, zoom
4
+ from PIL import Image
5
+
6
+
7
+
8
def pad_to_multiple(image: np.ndarray, multiple: int = 8):
    """Reflect-pad the bottom/right of *image* up to a multiple of *multiple*.

    Args:
        image: Array whose first two axes are treated as height and width.
            A third axis (when ``image.ndim == 3``) is left unpadded.
        multiple: Value both leading dimensions should become divisible by.

    Returns:
        A ``(padded, h, w)`` tuple where ``h`` and ``w`` are the height and
        width of the input before padding.
    """
    rows, cols = image.shape[:2]
    # ``-n % m`` is the distance from n up to the next multiple of m
    # (zero when n is already a multiple).
    pad_spec = [(0, -rows % multiple), (0, -cols % multiple)]
    if image.ndim == 3:
        # Leave the trailing (channel) axis untouched.
        pad_spec.append((0, 0))
    return np.pad(image, pad_spec, mode='reflect'), rows, cols
17
+
18
def crop_to_original(image: np.ndarray, h: int, w: int):
    """Return the top-left ``h`` x ``w`` region of *image*.

    Only the first two axes are sliced; any trailing axes are kept whole.
    """
    row_window = slice(None, h)
    col_window = slice(None, w)
    return image[row_window, col_window]
20
+
21
def wavelet_blur_np(image: np.ndarray, radius: int):
    """Blur every channel of *image* with a 3x3 kernel plus optional resampling.

    Each ``image[c]`` is convolved with a fixed 3x3 binomial kernel
    (``mode='nearest'``). When ``radius > 1`` the result is additionally
    shrunk by roughly ``1 / radius`` and enlarged back to its original size
    with linear interpolation, which widens the effective blur.

    Args:
        image: Array indexed as ``image[c]`` where every ``image[c]`` is 2-D
            (the 3x3 kernel is 2-D), i.e. channel-first layout.
        radius: Resampling factor; values ``<= 1`` apply the 3x3 blur only.

    Returns:
        Array with the same shape and dtype as *image*.
    """
    kernel = np.array(
        [
            [0.0625, 0.125, 0.0625],
            [0.125, 0.25, 0.125],
            [0.0625, 0.125, 0.0625],
        ],
        dtype=np.float32,
    )

    blurred = np.empty_like(image)
    for c, channel in enumerate(image):
        blurred_c = convolve(channel, kernel, mode='nearest')
        if radius > 1 and blurred_c.size:
            full_shape = blurred_c.shape
            # Keep at least one sample per axis so tiny inputs do not
            # collapse to an empty array.
            small_shape = [
                max(1, int(round(n * (1 / radius)))) for n in full_shape
            ]
            shrink = [s / n for s, n in zip(small_shape, full_shape)]
            small = zoom(blurred_c, shrink, order=1)
            # Upsample by the exact per-axis ratio rather than by ``radius``:
            # ``small * radius`` only equals the original size when the size
            # is divisible by ``radius``; otherwise the assignment below
            # would fail with a shape mismatch.
            grow = [n / s for n, s in zip(full_shape, small.shape)]
            blurred_c = zoom(small, grow, order=1)
        blurred[c] = blurred_c
    return blurred
35
+
36
def wavelet_decomposition_np(image: np.ndarray, levels=5):
    """Split *image* into accumulated high-frequency detail and a low-pass base.

    At level ``i`` the current image is blurred with
    ``wavelet_blur_np(..., radius=2 ** i)``; the difference between the
    current image and its blur is added to the high-frequency accumulator,
    and the blur becomes the input of the next level. Because the
    differences telescope, ``high_freq + low_freq`` equals the input.

    Args:
        image: Array in the layout expected by ``wavelet_blur_np``.
        levels: Number of blur levels. With ``levels <= 0`` no blurring is
            done and the result is ``(zeros_like(image), image)``.

    Returns:
        ``(high_freq, low_freq)``. ``high_freq`` has the dtype of *image*
        (it is created with ``np.zeros_like``).
    """
    high_freq = np.zeros_like(image)
    # Bind low_freq up front so that zero levels returns the input as the
    # low-pass component instead of raising UnboundLocalError.
    low_freq = image
    current = image
    for level in range(levels):
        low_freq = wavelet_blur_np(current, 2 ** level)
        high_freq += current - low_freq
        current = low_freq
    return high_freq, low_freq
44
+
45
def wavelet_reconstruction_np(content_feat: np.ndarray, style_feat: np.ndarray):
    """Combine the detail of *content_feat* with the base of *style_feat*.

    Both inputs are decomposed with ``wavelet_decomposition_np`` (default
    number of levels). The high-frequency part of *content_feat* is added to
    the low-frequency part of *style_feat*.
    """
    detail = wavelet_decomposition_np(content_feat)[0]
    base = wavelet_decomposition_np(style_feat)[1]
    return detail + base
49
+
50
def wavelet_color_fix_np(fused: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Merge high frequencies of *fused* with low frequencies of *mask*.

    Both inputs are converted to float32, divided by 255 and transposed from
    axis order (0, 1, 2) to (2, 0, 1) before being passed to
    ``wavelet_reconstruction_np`` (``fused`` as content, ``mask`` as style).
    The result is transposed back, multiplied by 255, clipped to [0, 255]
    and cast to ``uint8``.

    Args:
        fused: 3-D array supplying the high-frequency detail
            (presumably H x W x C with values in 0..255 -- confirm with caller).
        mask: 3-D array supplying the low-frequency content, same layout.

    Returns:
        ``uint8`` array in the original axis order.
    """
    def _to_unit_channel_first(arr):
        # Scale to 0..1 and move the last axis to the front.
        return np.transpose(arr.astype(np.float32) / 255.0, (2, 0, 1))

    merged = wavelet_reconstruction_np(
        _to_unit_channel_first(fused),
        _to_unit_channel_first(mask),
    )

    # Back to the original axis order and the 0..255 range.
    merged = np.transpose(merged, (1, 2, 0)) * 255.0
    return np.clip(merged, 0, 255).astype(np.uint8)
63
+
64
def attention_guided_fusion(ori: np.ndarray, removed: np.ndarray, attn_map: np.ndarray, multiple: int = 8):
    """Blend *ori* and *removed* under a mask derived from *attn_map*.

    Steps:
      1. Binarise *attn_map* with ``cv2.threshold`` (threshold 128, max value
         255), scale to 0..1 and resize to the height/width of *ori* using
         nearest-neighbour interpolation.
      2. Build a softened mask: dilate with a 21x21 elliptical kernel, apply
         a 9x9 Gaussian blur (``sigmaX=2``), take the element-wise maximum
         with the hard mask and clip to [0, 1].
      3. Zero out the hard-mask region in both images, pad them to a multiple
         of *multiple*, and run ``wavelet_color_fix_np`` with the masked
         *ori* as first argument and the masked *removed* as second.
      4. Crop back and blend: the wavelet result outside the soft mask,
         *removed* inside it.

    Args:
        ori: Original image (presumably H x W x 3 with values in 0..255 --
            confirm with caller).
        removed: Image of the same shape as *ori* used inside the mask.
        attn_map: Attention map; the masks are stacked to three channels,
            so a 2-D map is assumed.
        multiple: Padding granularity forwarded to ``pad_to_multiple``.

    Returns:
        The blended image cast to ``uint8``.
    """
    height, width = ori.shape[:2]

    # Hard 0/1 mask at the resolution of the original image.
    _, binary = cv2.threshold(attn_map.astype(np.float32), 128, 255, cv2.THRESH_BINARY)
    hard_mask = cv2.resize(
        binary.astype(np.float32) / 255.0,
        (width, height),
        interpolation=cv2.INTER_NEAREST,
    )

    # Grown and feathered version of the mask, used for the final blend.
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (21, 21))
    grown = cv2.dilate(hard_mask, ellipse, iterations=1)
    feathered = cv2.GaussianBlur(grown.astype(np.float32), (9, 9), sigmaX=2)
    soft_mask = np.clip(np.maximum(hard_mask, feathered), 0, 1)

    # Replicate both masks across three channels.
    soft_3c = np.repeat(soft_mask[..., np.newaxis], 3, axis=-1)
    hard_3c = np.repeat(hard_mask[..., np.newaxis], 3, axis=-1)

    # Suppress the hard-mask region in both inputs before the wavelet step.
    outside = 1 - hard_3c
    ori_pad, h0, w0 = pad_to_multiple(ori * outside, multiple)
    rem_pad, _, _ = pad_to_multiple(removed * outside, multiple)

    wave = crop_to_original(wavelet_color_fix_np(ori_pad, rem_pad), h0, w0)

    # fusion
    blended = wave * (1 - soft_3c) + removed * soft_3c
    return blended.astype(np.uint8)
93
+
94
+
95
def resize_by_short_side(image, target_short=512, resample=Image.BICUBIC):
    """Resize *image* so its shorter side equals *target_short*.

    The longer side is scaled by the same ratio, truncated to an integer and
    then rounded up to the next multiple of 16. When width and height are
    equal, the height is treated as the short side.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``resize((w, h), resample=...)`` method (presumably a
            ``PIL.Image.Image``).
        target_short: Desired length of the shorter side, in pixels.
        resample: Resampling filter passed through to ``image.resize``.

    Returns:
        Whatever ``image.resize`` returns for the computed size.
    """
    width, height = image.size

    def _scaled_long_side(long_side, short_side):
        # Scale proportionally, then round up to a multiple of 16.
        scaled = int(long_side * target_short / short_side)
        return (scaled + 15) // 16 * 16

    if width < height:
        new_size = (target_short, _scaled_long_side(height, width))
    else:
        new_size = (_scaled_long_side(width, height), target_short)
    return image.resize(new_size, resample=resample)