HelloTestUser committed on
Commit
ac6c809
·
verified ·
1 Parent(s): 831a2d9

Create app_local.py

Browse files
Files changed (1) hide show
  1. app_local.py +318 -0
app_local.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ import gradio as gr
6
+ # import spaces
7
+
8
+ from glob import glob
9
+ from typing import Tuple
10
+
11
+ from PIL import Image
12
+ # from gradio_imageslider import ImageSlider
13
+ import transformers
14
+ import torch
15
+ from torchvision import transforms
16
+
17
+ import requests
18
+ from io import BytesIO
19
+ import zipfile
20
+
21
+
22
+ torch.set_float32_matmul_precision('high')
23
+ # torch.jit.script = lambda f: f
24
+
25
+ device = "cuda" if torch.cuda.is_available() else "cpu"
26
+
27
+
28
## CPU version refinement
def FB_blur_fusion_foreground_estimator_cpu(image, FG, B, alpha, r=90):
    """One blur-fusion foreground-estimation pass (CPU / NumPy + OpenCV).

    Blends box-blurred foreground/background estimates back into the image
    to recover foreground colors under the alpha matte.

    Args:
        image: H x W x C float array in [0, 1], or a PIL image (converted here).
        FG: current foreground estimate, same shape as `image`.
        B: current background estimate, same shape as `image`.
        alpha: alpha matte in [0, 1]; the caller passes it as H x W x 1
            (cv2.blur drops the trailing singleton, re-added below).
        r: box-blur kernel size in pixels.

    Returns:
        (FG, blurred_B): refined foreground clipped to [0, 1], and the
        blurred background estimate for a follow-up refinement pass.
    """
    if isinstance(image, Image.Image):
        image = np.array(image) / 255.0
    # cv2.blur returns H x W for a single-channel input; [:, :, None]
    # restores the channel axis so broadcasting below works.
    blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None]

    blurred_FGA = cv2.blur(FG * alpha, (r, r))
    # 1e-5 guards against division by zero in fully transparent regions.
    blurred_FG = blurred_FGA / (blurred_alpha + 1e-5)

    blurred_B1A = cv2.blur(B * (1 - alpha), (r, r))
    blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)
    FG = blurred_FG + alpha * (image - alpha * blurred_FG - (1 - alpha) * blurred_B)
    FG = np.clip(FG, 0, 1)
    return FG, blurred_B
42
def FB_blur_fusion_foreground_estimator_cpu_2(image, alpha, r=90):
    """Two-pass blur-fusion foreground estimation on the CPU.

    Runs a coarse pass (kernel size `r`) seeded with the image itself, then a
    fine pass (kernel size 6), and returns only the refined foreground.
    Thanks to the source: https://github.com/Photoroom/fast-foreground-estimation
    """
    alpha = alpha[:, :, None]
    coarse_FG, coarse_B = FB_blur_fusion_foreground_estimator_cpu(
        image, image, image, alpha, r)
    refined_FG, _ = FB_blur_fusion_foreground_estimator_cpu(
        image, coarse_FG, coarse_B, alpha, r=6)
    return refined_FG
51
## GPU version refinement
def mean_blur(x, kernel_size):
    """
    equivalent to cv.blur
    x: [B, C, H, W]
    """
    half = kernel_size // 2
    if kernel_size % 2 == 0:
        # Even kernels have no exact center: pad one pixel less on the
        # left/top so the output keeps the input's spatial size.
        padding = (half - 1, half, half - 1, half)
    else:
        padding = (half, half, half, half)

    padded = torch.nn.functional.pad(x, padding, mode='replicate')
    # stride=1 over the padded tensor yields one full window per input pixel.
    return torch.nn.functional.avg_pool2d(
        padded, kernel_size=(kernel_size, kernel_size), stride=1,
        count_include_pad=False)
68
def FB_blur_fusion_foreground_estimator_gpu(image, FG, B, alpha, r=90):
    """One blur-fusion foreground-estimation pass on tensors (GPU-friendly).

    All tensors are [B, C, H, W]; `alpha` broadcasts over channels. The math
    runs in float32 and the results are cast back to the input dtype.
    """
    input_dtype = image.dtype

    def as_f32(t):
        # Promote to float32 to avoid overflow / precision loss in the blurs.
        return t.to(torch.float32) if t.dtype != torch.float32 else t

    image, FG, B, alpha = (as_f32(t) for t in (image, FG, B, alpha))

    blurred_alpha = mean_blur(alpha, kernel_size=r)

    # 1e-5 avoids division by zero where the matte is exactly 0 or 1.
    blurred_FGA = mean_blur(FG * alpha, kernel_size=r)
    blurred_FG = blurred_FGA / (blurred_alpha + 1e-5)

    blurred_B1A = mean_blur(B * (1 - alpha), kernel_size=r)
    blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)

    FG_output = blurred_FG + alpha * (image - alpha * blurred_FG - (1 - alpha) * blurred_B)
    FG_output = torch.clamp(FG_output, 0, 1)

    # .to(dtype) is a no-op when the dtype already matches.
    return FG_output.to(input_dtype), blurred_B.to(input_dtype)
91
def FB_blur_fusion_foreground_estimator_gpu_2(image, alpha, r=90):
    """Two-pass blur-fusion foreground estimation on tensors.

    Thanks to the source: https://github.com/ZhengPeng7/BiRefNet/issues/226#issuecomment-3016433728
    """
    first_FG, first_B = FB_blur_fusion_foreground_estimator_gpu(image, image, image, alpha, r)
    refined_FG, _ = FB_blur_fusion_foreground_estimator_gpu(image, first_FG, first_B, alpha, r=6)
    return refined_FG
99
def refine_foreground(image, mask, r=90, device='cuda'):
    """both image and mask are in range of [0, 1]"""
    # Ensure the matte matches the image resolution before estimation.
    if mask.size != image.size:
        mask = mask.resize(image.size)

    if device == 'cuda':
        # Tensor path: build [1, C, H, W] batches on the GPU, then convert
        # the refined foreground back to a uint8 HWC array.
        image_t = transforms.functional.to_tensor(image).float().cuda().unsqueeze(0)
        mask_t = transforms.functional.to_tensor(mask).float().cuda().unsqueeze(0)

        foreground = FB_blur_fusion_foreground_estimator_gpu_2(image_t, mask_t, r=r)
        foreground = foreground.squeeze().mul(255.0).to(torch.uint8)
        foreground = foreground.permute(1, 2, 0).contiguous().cpu().numpy().astype(np.uint8)
    else:
        # NumPy path: normalize both inputs to [0, 1] floats first.
        image_np = np.array(image, dtype=np.float32) / 255.0
        mask_np = np.array(mask, dtype=np.float32) / 255.0
        foreground = FB_blur_fusion_foreground_estimator_cpu_2(image_np, mask_np, r=r)
        foreground = (foreground * 255.0).astype(np.uint8)

    return Image.fromarray(np.ascontiguousarray(foreground))
125
class ImagePreprocessor():
    """Resize, tensorize, and ImageNet-normalize a PIL image for inference."""

    def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
        # Input resolution is on WxH; torchvision's Resize expects (H, W),
        # hence the reversal.
        height_width = resolution[::-1]
        self.transform_image = transforms.Compose([
            transforms.Resize(height_width),
            transforms.ToTensor(),
            # Standard ImageNet channel statistics.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def proc(self, image: Image.Image) -> torch.Tensor:
        """Apply the transform pipeline and return a [C, H, W] tensor."""
        return self.transform_image(image)
139
# Mapping from the UI weight-choice labels to the HF Hub repository names
# under the 'zhengpeng7' namespace.
usage_to_weights_file = {
    'General': 'BiRefNet',
    'General-HR': 'BiRefNet_HR',
    'Matting-HR': 'BiRefNet_HR-matting',
    'Matting': 'BiRefNet-matting',
    'Portrait': 'BiRefNet-portrait',
    'General-reso_512': 'BiRefNet_512x512',
    'General-Lite': 'BiRefNet_lite',
    'General-Lite-2K': 'BiRefNet_lite-2K',
    'Anime-Lite': 'BiRefNet_lite-Anime',
    'DIS': 'BiRefNet-DIS5K',
    'HRSOD': 'BiRefNet-HRSOD',
    'COD': 'BiRefNet-COD',
    'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs',
    'General-legacy': 'BiRefNet-legacy',
    'General-dynamic': 'BiRefNet_dynamic',
}

# Eagerly load the default ('General') weights at import time so the first
# request does not pay the download/initialization cost. `predict` may
# replace this global with other weights.
birefnet = transformers.AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', usage_to_weights_file['General'])), trust_remote_code=True)
birefnet.to(device)
# NOTE(review): fp16 halves memory but is slow on CPU — this assumes the
# local app runs with CUDA available; confirm for CPU-only deployments.
birefnet.eval(); birefnet.half()
162
# @spaces.GPU
def predict(images, resolution, weights_file):
    """Run BiRefNet segmentation on a single image, an image URL, or a batch.

    Args:
        images: a numpy array (image tab), a file path or URL string
            (URL tab), or a list of file paths (batch tab).
        resolution: 'WxH' string, e.g. '1024x1024'; anything unparseable
            falls back to a per-model default resolution.
        weights_file: key of `usage_to_weights_file` selecting the weights.

    Returns:
        (foreground_image, original_image) for single inputs, or
        (list_of_saved_paths, zip_file_path) for batch inputs.

    Raises:
        ValueError: if `images` is None.
    """
    # Explicit raise instead of `assert`: asserts are stripped under -O.
    if images is None:
        raise ValueError('images cannot be None.')

    global birefnet
    # Load BiRefNet with the chosen weights -- but only when they differ
    # from what is already in memory; reloading from the Hub on every call
    # is needlessly expensive.
    _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else usage_to_weights_file['General']))
    print('Using weights: {}.'.format(_weights_file))
    if getattr(predict, '_loaded_weights', None) != _weights_file:
        birefnet = transformers.AutoModelForImageSegmentation.from_pretrained(_weights_file, trust_remote_code=True)
        birefnet.to(device)
        birefnet.eval(); birefnet.half()
        predict._loaded_weights = _weights_file

    try:
        # Parse 'WxH' and round each side down to a multiple of 32.
        resolution = [int(int(reso) // 32 * 32) for reso in resolution.strip().split('x')]
        if len(resolution) != 2:
            # Inputs like '1024' are not a WxH pair; use defaults instead.
            raise ValueError('resolution must be a WxH pair')
    except (AttributeError, ValueError):
        # AttributeError: resolution was None; ValueError: unparseable text.
        if weights_file in ['General-HR', 'Matting-HR']:
            resolution = (2048, 2048)
        elif weights_file in ['General-Lite-2K']:
            resolution = (2560, 1440)
        elif weights_file in ['General-reso_512']:
            resolution = (512, 512)
        else:
            if weights_file in ['General-dynamic']:
                resolution = None
                print('Using the original size (div by 32) for inference.')
            else:
                resolution = (1024, 1024)
                print('Invalid resolution input. Automatically changed to 1024x1024 / 2048x2048 / 2560x1440.')

    if isinstance(images, list):
        # For tab_batch
        save_paths = []
        save_dir = 'preds-BiRefNet'
        os.makedirs(save_dir, exist_ok=True)
        tab_is_batch = True
    else:
        images = [images]
        tab_is_batch = False

    for idx_image, image_src in enumerate(images):
        if isinstance(image_src, str):
            if os.path.isfile(image_src):
                image_ori = Image.open(image_src)
            else:
                # Non-file strings are treated as URLs; bound the request
                # time so a dead link cannot hang the app forever.
                response = requests.get(image_src, timeout=60)
                image_data = BytesIO(response.content)
                image_ori = Image.open(image_data)
        else:
            image_ori = Image.fromarray(image_src)

        image = image_ori.convert('RGB')
        # Preprocess the image
        if resolution is None:
            # 'General-dynamic': infer at the original size, rounded down to
            # a multiple of 32 as the backbone requires. Subsequent batch
            # items reuse the first image's resolution (original behavior).
            resolution = [int(int(reso) // 32 * 32) for reso in image.size]
        image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
        image_proc = image_preprocessor.proc(image)
        image_proc = image_proc.unsqueeze(0)

        # Prediction: the last decoder output is the final mask logit map.
        with torch.no_grad():
            preds = birefnet(image_proc.to(device).half())[-1].sigmoid().cpu()
        pred = preds[0].squeeze()

        # Show Results
        pred_pil = transforms.ToPILImage()(pred)
        image_masked = refine_foreground(image, pred_pil, device=device)
        image_masked.putalpha(pred_pil.resize(image.size))

        # No-op when CUDA is uninitialized; frees cached GPU memory otherwise.
        torch.cuda.empty_cache()

        if tab_is_batch:
            save_file_path = os.path.join(save_dir, "{}.png".format(os.path.splitext(os.path.basename(image_src))[0]))
            image_masked.save(save_file_path)
            save_paths.append(save_file_path)

    if tab_is_batch:
        zip_file_path = os.path.join(save_dir, "{}.zip".format(save_dir))
        with zipfile.ZipFile(zip_file_path, 'w') as zipf:
            for file in save_paths:
                zipf.write(file, os.path.basename(file))
        return save_paths, zip_file_path
    else:
        return (image_masked, image_ori)
250
# Build Gradio example rows from the bundled files; each row is
# [image_path, resolution, model_choice].
examples = [[path] for path in glob('examples/*')]
for idx_example, example in enumerate(examples):
    # 'My_' examples are personal high-res mattes; everything else uses the
    # training defaults.
    if 'My_' in example[0]:
        example_resolution = '2048x2048'
        model_choice = 'Matting-HR'
    else:
        example_resolution = '1024x1024'
        model_choice = 'General'
    examples[idx_example] = examples[idx_example] + [example_resolution, model_choice]

examples_url = [
    ['https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg'],
]
for idx_example_url, example_url in enumerate(examples_url):
    examples_url[idx_example_url] = examples_url[idx_example_url] + ['1024x1024', 'General']

# Fixed: the original text contained a stray ')' after the first sentence.
descriptions = ('Upload a picture, our model will extract a highly accurate segmentation of the subject in it.\n'
                ' The resolution used in our training was `1024x1024`, which is the suggested resolution to obtain good results! `2048x2048` is suggested for BiRefNet_HR.\n'
                ' Our codes can be found at https://github.com/ZhengPeng7/BiRefNet.\n'
                ' We also maintain the HF model of BiRefNet at https://huggingface.co/ZhengPeng7/BiRefNet for easier access.')

# Tab 1: single image upload.
tab_image = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(label='Upload an image'),
        gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `1024x1024`.", label="Resolution"),
        gr.Radio(list(usage_to_weights_file.keys()), value='General', label="Weights", info="Choose the weights you want.")
    ],
    outputs=gr.ImageSlider(label="BiRefNet's prediction", type="pil", format='png'),
    examples=examples,
    api_name="image",
    description=descriptions,
)

# Tab 2: image fetched from a URL.
tab_text = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="Paste an image URL"),
        gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `1024x1024`.", label="Resolution"),
        gr.Radio(list(usage_to_weights_file.keys()), value='General', label="Weights", info="Choose the weights you want.")
    ],
    outputs=gr.ImageSlider(label="BiRefNet's prediction", type="pil", format='png'),
    examples=examples_url,
    api_name="URL",
    description=descriptions+'\nTab-URL is partially modified from https://huggingface.co/spaces/not-lain/background-removal, thanks to this great work!',
)

# Tab 3: batch of uploaded files; returns a gallery plus a zip download.
tab_batch = gr.Interface(
    fn=predict,
    inputs=[
        gr.File(label="Upload multiple images", type="filepath", file_count="multiple"),
        gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `1024x1024`.", label="Resolution"),
        gr.Radio(list(usage_to_weights_file.keys()), value='General', label="Weights", info="Choose the weights you want.")
    ],
    outputs=[gr.Gallery(label="BiRefNet's predictions"), gr.File(label="Download masked images.")],
    api_name="batch",
    description=descriptions+'\nTab-batch is partially modified from https://huggingface.co/spaces/NegiTurkey/Multi_Birefnetfor_Background_Removal, thanks to this great work!',
)

demo = gr.TabbedInterface(
    [tab_image, tab_text, tab_batch],
    ['image', 'URL', 'batch'],
    title="Official Online Demo of BiRefNet",
)

if __name__ == "__main__":
    demo.launch(debug=True)