HelloTestUser committed on
Commit
831a2d9
·
verified ·
1 Parent(s): 9f70a5f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +323 -0
app.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ import gradio as gr
6
+ import spaces
7
+
8
+ from glob import glob
9
+ from typing import Tuple
10
+
11
+ from PIL import Image
12
+ import torch
13
+ from torchvision import transforms
14
+
15
+ import requests
16
+ from io import BytesIO
17
+ import zipfile
18
+
19
+ # Fix the HF space permission error when using from_pretrained(..., trust_remote_code=True)
20
+ os.environ["HF_MODULES_CACHE"] = os.path.join("/tmp/hf_cache", "modules")
21
+
22
+ import transformers
23
+ transformers.utils.move_cache()
24
+
25
+
26
+ torch.set_float32_matmul_precision('high')
27
+ torch.jit.script = lambda f: f
28
+
29
+ device = "cuda" if torch.cuda.is_available() else "cpu"
30
+
31
+
32
## CPU version refinement
def FB_blur_fusion_foreground_estimator_cpu(image, FG, B, alpha, r=90):
    """One blur-fusion step on numpy arrays.

    `image`, `FG` and `B` are HxWxC float arrays in [0, 1] (a PIL image is
    converted on entry); `alpha` is HxWx1. Returns the refined foreground
    estimate (clipped to [0, 1]) and the blurred background estimate.
    """
    if isinstance(image, Image.Image):
        image = np.array(image) / 255.0

    alpha_blur = cv2.blur(alpha, (r, r))[:, :, None]

    # Alpha-weighted box blurs of foreground and background, renormalized by
    # the blurred alpha so fully transparent/opaque regions don't bleed.
    fg_blur = cv2.blur(FG * alpha, (r, r)) / (alpha_blur + 1e-5)
    bg_blur = cv2.blur(B * (1 - alpha), (r, r)) / ((1 - alpha_blur) + 1e-5)

    # Correct the blurred foreground toward the observed composite.
    refined = fg_blur + alpha * (image - alpha * fg_blur - (1 - alpha) * bg_blur)
    return np.clip(refined, 0, 1), bg_blur
46
+
47
+
48
def FB_blur_fusion_foreground_estimator_cpu_2(image, alpha, r=90):
    """Two-pass CPU blur-fusion refinement: a coarse pass at radius `r`, then a fine pass at radius 6."""
    # Thanks to the source: https://github.com/Photoroom/fast-foreground-estimation
    alpha = alpha[:, :, None]
    coarse_FG, coarse_B = FB_blur_fusion_foreground_estimator_cpu(image, image, image, alpha, r)
    final_FG, _ = FB_blur_fusion_foreground_estimator_cpu(image, coarse_FG, coarse_B, alpha, r=6)
    return final_FG
53
+
54
+
55
## GPU version refinement
def mean_blur(x, kernel_size):
    """Box filter on a [B, C, H, W] tensor, equivalent to cv2.blur.

    Uses replicate padding so the output keeps the input's spatial size;
    for even kernels the extra padding goes on the right/bottom (matching
    OpenCV's anchor convention).
    """
    half = kernel_size // 2
    if kernel_size % 2 == 0:
        padding = (half - 1, half, half - 1, half)
    else:
        padding = (half, half, half, half)

    padded = torch.nn.functional.pad(x, padding, mode='replicate')

    return torch.nn.functional.avg_pool2d(
        padded,
        kernel_size=(kernel_size, kernel_size),
        stride=1,
        count_include_pad=False,
    )
72
+
73
def FB_blur_fusion_foreground_estimator_gpu(image, FG, B, alpha, r=90):
    """One blur-fusion step on [B, C, H, W] tensors.

    Returns the refined foreground (clamped to [0, 1]) and the blurred
    background, both converted back to the input's dtype.
    """
    input_dtype = image.dtype

    def _as_float32(t):
        # Work in float32 during the blurs to avoid overflow/precision loss.
        return t if t.dtype == torch.float32 else t.to(torch.float32)

    image = _as_float32(image)
    FG = _as_float32(FG)
    B = _as_float32(B)
    alpha = _as_float32(alpha)

    alpha_blur = mean_blur(alpha, kernel_size=r)

    # Alpha-weighted blurs renormalized by the blurred alpha.
    fg_blur = mean_blur(FG * alpha, kernel_size=r) / (alpha_blur + 1e-5)
    bg_blur = mean_blur(B * (1 - alpha), kernel_size=r) / ((1 - alpha_blur) + 1e-5)

    # Correct the blurred foreground toward the observed composite.
    refined = fg_blur + alpha * (image - alpha * fg_blur - (1 - alpha) * bg_blur)
    refined = torch.clamp(refined, 0, 1)

    return refined.to(input_dtype), bg_blur.to(input_dtype)
95
+
96
+
97
def FB_blur_fusion_foreground_estimator_gpu_2(image, alpha, r=90):
    """Two-pass GPU blur-fusion refinement: a coarse pass at radius `r`, then a fine pass at radius 6."""
    # Thanks to the source: https://github.com/ZhengPeng7/BiRefNet/issues/226#issuecomment-3016433728
    coarse_FG, coarse_B = FB_blur_fusion_foreground_estimator_gpu(image, image, image, alpha, r)
    final_FG, _ = FB_blur_fusion_foreground_estimator_gpu(image, coarse_FG, coarse_B, alpha, r=6)
    return final_FG
101
+
102
+
103
def refine_foreground(image, mask, r=90, device='cuda'):
    """Estimate a clean foreground from a PIL image and its predicted alpha mask.

    Both `image` and `mask` are PIL images; the mask is resized to the image
    size if they differ. Returns the refined foreground as a PIL image.
    """
    if mask.size != image.size:
        mask = mask.resize(image.size)

    if device == 'cuda':
        # Tensor path: batched [1, C, H, W] float tensors on the GPU.
        image_t = transforms.functional.to_tensor(image).float().cuda().unsqueeze(0)
        mask_t = transforms.functional.to_tensor(mask).float().cuda().unsqueeze(0)

        fg_t = FB_blur_fusion_foreground_estimator_gpu_2(image_t, mask_t, r=r)

        fg_t = fg_t.squeeze().mul(255.0).to(torch.uint8)
        fg_array = fg_t.permute(1, 2, 0).contiguous().cpu().numpy().astype(np.uint8)
    else:
        # Numpy path: normalize both inputs to [0, 1] floats.
        image_np = np.array(image, dtype=np.float32) / 255.0
        mask_np = np.array(mask, dtype=np.float32) / 255.0
        fg = FB_blur_fusion_foreground_estimator_cpu_2(image_np, mask_np, r=r)
        fg_array = (fg * 255.0).astype(np.uint8)

    return Image.fromarray(np.ascontiguousarray(fg_array))
128
+
129
+
130
class ImagePreprocessor():
    """Resize a PIL image and convert it to an ImageNet-normalized tensor."""

    def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
        # `resolution` is given as (W, H); torchvision's Resize expects (H, W),
        # hence the reversal.
        self.transform_image = transforms.Compose([
            transforms.Resize(resolution[::-1]),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def proc(self, image: Image.Image) -> torch.Tensor:
        """Apply the resize/normalize pipeline to one image."""
        return self.transform_image(image)
142
+
143
+
144
# Maps the UI-facing weight names to the repo suffixes under `zhengpeng7/` on the HF Hub.
usage_to_weights_file = {
    'General': 'BiRefNet',
    'General-HR': 'BiRefNet_HR',
    'Matting-HR': 'BiRefNet_HR-matting',
    'Matting': 'BiRefNet-matting',
    'Portrait': 'BiRefNet-portrait',
    'General-reso_512': 'BiRefNet_512x512',
    'General-Lite': 'BiRefNet_lite',
    'General-Lite-2K': 'BiRefNet_lite-2K',
    # 'Anime-Lite': 'BiRefNet_lite-Anime',
    'DIS': 'BiRefNet-DIS5K',
    'HRSOD': 'BiRefNet-HRSOD',
    'COD': 'BiRefNet-COD',
    'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs',
    'General-legacy': 'BiRefNet-legacy',
    'General-dynamic': 'BiRefNet_dynamic',
    'Matting-dynamic': 'BiRefNet_dynamic-matting',
}

# Default model loaded once at startup (downloads from the Hub on first run);
# `predict` rebinds this global with the weights chosen in the UI.
birefnet = transformers.AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', usage_to_weights_file['General'])), trust_remote_code=True)
birefnet.to(device)
# Inference only, in half precision to cut GPU memory use.
birefnet.eval(); birefnet.half()
166
+
167
+
168
def _open_image(image_src):
    """Resolve one input — local file path, URL string, or numpy array — to a PIL image."""
    if isinstance(image_src, str):
        if os.path.isfile(image_src):
            return Image.open(image_src)
        response = requests.get(image_src)
        return Image.open(BytesIO(response.content))
    return Image.fromarray(image_src)


@spaces.GPU
def predict(images, resolution, weights_file):
    """Segment the subject(s) of one image or a batch of images with BiRefNet.

    Args:
        images: a numpy image, a path/URL string, or a list of them (batch tab).
        resolution: 'WxH' string; anything unparsable falls back to a
            per-model default resolution.
        weights_file: key into `usage_to_weights_file`, or None for 'General'.

    Returns:
        (masked_image, original_image) for a single input, or
        (list_of_saved_paths, zip_file_path) for a batch.
    """
    # Was a bare `assert`, which is stripped under `python -O`; raise instead.
    if images is None:
        raise ValueError('images cannot be None.')

    global birefnet
    # Load BiRefNet with the chosen weights (default to 'General' when unset).
    weights_key = weights_file if weights_file is not None else 'General'
    _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_key]))
    print('Using weights: {}.'.format(_weights_file))
    birefnet = transformers.AutoModelForImageSegmentation.from_pretrained(_weights_file, trust_remote_code=True)
    birefnet.to(device)
    birefnet.eval(); birefnet.half()

    try:
        resolution = [int(int(reso) // 32 * 32) for reso in resolution.strip().split('x')]
    except (AttributeError, ValueError):
        # Was a bare `except:`; only the parse failures (None / non-'WxH' input)
        # are expected here. Using `weights_key` below also fixes a crash where
        # `'_dynamic' in weights_file` raised TypeError when weights_file was None.
        if weights_key in ('General-HR', 'Matting-HR'):
            resolution = (2048, 2048)
        elif weights_key == 'General-Lite-2K':
            resolution = (2560, 1440)
        elif weights_key == 'General-reso_512':
            resolution = (512, 512)
        elif '_dynamic' in weights_key:
            # Dynamic models infer at the original image size (rounded to /32).
            resolution = None
            print('Using the original size (div by 32) for inference.')
        else:
            resolution = (1024, 1024)
            print('Invalid resolution input. Automatically changed to 1024x1024 / 2048x2048 / 2560x1440.')

    if isinstance(images, list):
        # For tab_batch
        save_paths = []
        save_dir = 'preds-BiRefNet'
        os.makedirs(save_dir, exist_ok=True)
        tab_is_batch = True
    else:
        images = [images]
        tab_is_batch = False

    for idx_image, image_src in enumerate(images):
        image_ori = _open_image(image_src)

        image = image_ori.convert('RGB')
        # Preprocess the image. In dynamic mode the first image's /32-rounded
        # size is adopted (and then reused for the rest of a batch, as before).
        if resolution is None:
            resolution = [int(int(reso) // 32 * 32) for reso in image.size]
        image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
        image_proc = image_preprocessor.proc(image).unsqueeze(0)

        # Prediction: last output head, sigmoid to [0, 1], moved to CPU.
        with torch.no_grad():
            preds = birefnet(image_proc.to(device).half())[-1].sigmoid().cpu()
        pred = preds[0].squeeze()

        # Compose the result: refined foreground with the mask as alpha channel.
        pred_pil = transforms.ToPILImage()(pred)
        image_masked = refine_foreground(image, pred_pil, device=device)
        image_masked.putalpha(pred_pil.resize(image.size))

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        if tab_is_batch:
            save_file_path = os.path.join(save_dir, "{}.png".format(os.path.splitext(os.path.basename(image_src))[0]))
            image_masked.save(save_file_path)
            save_paths.append(save_file_path)

    if tab_is_batch:
        # Bundle all saved predictions into one downloadable zip.
        zip_file_path = os.path.join(save_dir, "{}.zip".format(save_dir))
        with zipfile.ZipFile(zip_file_path, 'w') as zipf:
            for file in save_paths:
                zipf.write(file, os.path.basename(file))
        return save_paths, zip_file_path
    else:
        return (image_masked, image_ori)
254
+
255
+
256
# Local example images; each entry becomes [path, resolution, weights choice].
examples = [[path] for path in glob('examples/*')]
for idx_example, example in enumerate(examples):
    # Personal high-res samples (prefixed 'My_') use the HR matting weights.
    if 'My_' in example[0]:
        extra_fields = ['2048x2048', 'Matting-HR']
    else:
        extra_fields = ['1024x1024', 'General']
    examples[idx_example] = examples[idx_example] + extra_fields

# URL-based examples for the URL tab, padded with the default options.
examples_url = [
    ['https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg'],
]
for idx_example_url in range(len(examples_url)):
    examples_url[idx_example_url] = examples_url[idx_example_url] + ['1024x1024', 'General']
272
+
273
# Shared UI description. Fix: the original first line ended with `.\n)` — the
# `)` was inside the string literal and rendered as a stray parenthesis.
descriptions = ('Upload a picture, our model will extract a highly accurate segmentation of the subject in it.\n'
                ' The resolution used in our training was `1024x1024`, which is the suggested resolution to obtain good results! `2048x2048` is suggested for BiRefNet_HR.\n'
                ' Our codes can be found at https://github.com/ZhengPeng7/BiRefNet.\n'
                ' We also maintain the HF model of BiRefNet at https://huggingface.co/ZhengPeng7/BiRefNet for easier access.')
277
+
278
# Single-image tab: upload an image; output is an image slider comparing the
# masked result with the original.
tab_image = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(label='Upload an image'),
        gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `1024x1024`.", label="Resolution"),
        gr.Radio(list(usage_to_weights_file.keys()), value='General', label="Weights", info="Choose the weights you want.")
    ],
    outputs=gr.ImageSlider(label="BiRefNet's prediction", type="pil", format='png'),
    examples=examples,
    api_name="image",
    description=descriptions,
)

# URL tab: same pipeline, but the image is fetched from a pasted URL
# (handled by the string branch inside `predict`).
tab_text = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="Paste an image URL"),
        gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `1024x1024`.", label="Resolution"),
        gr.Radio(list(usage_to_weights_file.keys()), value='General', label="Weights", info="Choose the weights you want.")
    ],
    outputs=gr.ImageSlider(label="BiRefNet's prediction", type="pil", format='png'),
    examples=examples_url,
    api_name="URL",
    description=descriptions+'\nTab-URL is partially modified from https://huggingface.co/spaces/not-lain/background-removal, thanks to this great work!',
)

# Batch tab: multiple files in; outputs a gallery of masked images plus a
# zip archive for download (the list branch inside `predict`).
tab_batch = gr.Interface(
    fn=predict,
    inputs=[
        gr.File(label="Upload multiple images", type="filepath", file_count="multiple"),
        gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `1024x1024`.", label="Resolution"),
        gr.Radio(list(usage_to_weights_file.keys()), value='General', label="Weights", info="Choose the weights you want.")
    ],
    outputs=[gr.Gallery(label="BiRefNet's predictions"), gr.File(label="Download masked images.")],
    api_name="batch",
    description=descriptions+'\nTab-batch is partially modified from https://huggingface.co/spaces/NegiTurkey/Multi_Birefnetfor_Background_Removal, thanks to this great work!',
)

# Combine the three tabs into the demo served by this Space.
demo = gr.TabbedInterface(
    [tab_image, tab_text, tab_batch],
    ['image', 'URL', 'batch'],
    title="Official Online Demo of BiRefNet",
)

if __name__ == "__main__":
    demo.launch(debug=True, ssr_mode=False)